# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest

from asserts import assert_gpu_and_cpu_are_equal_collect, assert_gpu_and_cpu_are_equal_sql, assert_gpu_fallback_collect, assert_gpu_sql_fallback_collect
from data_gen import *
from marks import *
from pyspark.sql.types import *
from pyspark.sql.types import DateType, TimestampType, NumericType
from pyspark.sql.window import Window
import pyspark.sql.functions as f
from spark_session import is_before_spark_320, is_databricks113_or_later, is_databricks133_or_later, is_spark_350_or_later, spark_version, with_cpu_session
import warnings

# Apply the premerge_ci_1 mark to every test in this module so that the module is part of
# the `mvn verify` sanity check run by the pre-merge CI.
pytestmark = [pytest.mark.premerge_ci_1]

# Column specifications handed to gen_df(): each list holds (column name, data generator) pairs.
# In the tests of this file that use them, 'a' is the PARTITION BY key, while 'b' and 'c' are
# the ORDER BY and aggregated columns.

# Non-nullable long key; integer 'b' and unique long 'c'.
_grpkey_longs_with_no_nulls = [
    ('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
    ('b', IntegerGen()),
    ('c', UniqueLongGen())]

# Same layout as above, but the key column 'a' is generated as nullable.
_grpkey_longs_with_nulls = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', IntegerGen()),
    ('c', UniqueLongGen())]

# 'b' holds non-null dates restricted to calendar year 2020.
_grpkey_longs_with_dates = [
    ('a', RepeatSeqGen(LongGen(), length=2048)),
    ('b', DateGen(nullable=False, start=date(year=2020, month=1, day=1), end=date(year=2020, month=12, day=31))),
    ('c', UniqueLongGen())]

# 'b' holds nullable dates restricted to calendar year 2020.
_grpkey_longs_with_nullable_dates = [
    ('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
    ('b', DateGen(nullable=(True, 5.0), start=date(year=2020, month=1, day=1), end=date(year=2020, month=12, day=31))),
    ('c', UniqueLongGen())]

# 'b' holds non-null timestamps; 'c' is a plain integer column.
_grpkey_longs_with_timestamps = [
    ('a', RepeatSeqGen(LongGen(), length=2048)),
    ('b', TimestampGen(nullable=False)),
    ('c', IntegerGen())]

# 'b' holds nullable timestamps; 'c' is a plain integer column.
_grpkey_longs_with_nullable_timestamps = [
    ('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
    ('b', TimestampGen(nullable=(True, 5.0))),
    ('c', IntegerGen())]

# Column specifications (lists of (column name, data generator) pairs) whose 'b' and/or 'c'
# columns are decimals, floats or doubles. The key column 'a' is a repeated sequence of longs.

# Non-nullable key; non-null decimal(18, 3) in 'b'.
_grpkey_longs_with_decimals = [
    ('a', RepeatSeqGen(LongGen(nullable=False), length=20)),
    ('b', DecimalGen(precision=18, scale=3, nullable=False)),
    ('c', UniqueLongGen())]

# Nullable key; nullable decimal(18, 10) in 'b'.
_grpkey_longs_with_nullable_decimals = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', DecimalGen(precision=18, scale=10, nullable=True)),
    ('c', UniqueLongGen())]

# Nullable key; nullable decimal(23, 10) in 'b'.
_grpkey_longs_with_nullable_larger_decimals = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', DecimalGen(precision=23, scale=10, nullable=True)),
    ('c', UniqueLongGen())]

# Nullable key; nullable decimal(38, 2) in both 'b' and 'c'.
_grpkey_longs_with_nullable_largest_decimals = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', DecimalGen(precision=38, scale=2, nullable=True)),
    ('c', DecimalGen(precision=38, scale=2, nullable=True))]

# Nullable key; nullable float in 'b' and nullable integer in 'c'.
_grpkey_longs_with_nullable_floats = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', FloatGen(nullable=True)),
    ('c', IntegerGen(nullable=True))]

# Nullable key; nullable double in 'b' and nullable integer in 'c'.
_grpkey_longs_with_nullable_doubles = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', DoubleGen(nullable=True)),
    ('c', IntegerGen(nullable=True))]

# Nullable key; integer 'b' and nullable decimal(8, 3) in the aggregated column 'c'.
_grpkey_decimals_with_nulls = [
    ('a', RepeatSeqGen(LongGen(nullable=(True, 10.0)), length=20)),
    ('b', IntegerGen()),
    ('c', DecimalGen(precision=8, scale=3, nullable=True))]

# Column specifications used by test_window_aggs_for_range_numeric_date, where 'a' is the
# partition key, 'b' is the ORDER BY column of a RANGE frame and 'c' is aggregated.

_grpkey_byte_with_nulls = [
    ('a', RepeatSeqGen(int_gen, length=20)),
    # Restrict the generated values with min_val/max_val so that calculations on them do not overflow.
    ('b', ByteGen(nullable=True, min_val=-98, max_val=98, special_cases=[])),
    ('c', UniqueLongGen())]

_grpkey_short_with_nulls = [
    ('a', RepeatSeqGen(int_gen, length=20)),
    # Restrict the generated values with min_val/max_val so that calculations on them do not overflow.
    ('b', ShortGen(nullable=True, min_val=-32700, max_val=32700, special_cases=[])),
    ('c', UniqueLongGen())]

_grpkey_int_with_nulls = [
    ('a', RepeatSeqGen(int_gen, length=20)),
    # Restrict the generated values with min_val/max_val so that calculations on them do not overflow.
    ('b', IntegerGen(nullable=True, min_val=-2147483000, max_val=2147483000, special_cases=[])),
    ('c', UniqueLongGen())]

_grpkey_long_with_nulls = [
    ('a', RepeatSeqGen(int_gen, length=20)),
    # Restrict the generated values with min_val/max_val so that calculations on them do not overflow.
    ('b', LongGen(nullable=True, min_val=-9223372036854775000, max_val=9223372036854775000, special_cases=[])),
    ('c', UniqueLongGen())]

# Date variant: 'b' holds nullable dates restricted to calendar year 2020.
_grpkey_date_with_nulls = [
    ('a', RepeatSeqGen(int_gen, length=20)),
    ('b', DateGen(nullable=(True, 5.0), start=date(year=2020, month=1, day=1), end=date(year=2020, month=12, day=31))),
    ('c', UniqueLongGen())]

# Two-column specifications used by the xfail'ed test_window_aggs_for_ranges_numeric_*_overflow
# tests. Unlike the bounded specs, 'b' is created without min_val/max_val restrictions.

_grpkey_byte_with_nulls_with_overflow = [
    ('a', IntegerGen()),
    ('b', ByteGen(nullable=True))]

_grpkey_short_with_nulls_with_overflow = [
    ('a', IntegerGen()),
    ('b', ShortGen(nullable=True))]

_grpkey_int_with_nulls_with_overflow = [
    ('a', IntegerGen()),
    ('b', IntegerGen(nullable=True))]

_grpkey_long_with_nulls_with_overflow = [
    ('a', IntegerGen()),
    ('b', LongGen(nullable=True))]

# Lists of single-column data generators. They are not referenced in this part of the file;
# NOTE(review): presumably they parametrize partition/order-by and lead/lag tests further
# down -- confirm against their usages.

# The DoubleGen is created with an empty special_cases list.
part_and_order_gens = [long_gen, DoubleGen(special_cases=[]),
        string_gen, boolean_gen, timestamp_gen, DecimalGen(precision=18, scale=1),
        DecimalGen(precision=38, scale=1)]

# Same as part_and_order_gens except that byte_gen takes the place of boolean_gen.
running_part_and_order_gens = [long_gen, DoubleGen(special_cases=[]),
        string_gen, byte_gen, timestamp_gen, DecimalGen(precision=18, scale=1),
        DecimalGen(precision=38, scale=1)]

# Includes a struct generator with int, date and string children in addition to the basic types.
lead_lag_data_gens = [long_gen, DoubleGen(special_cases=[]),
        boolean_gen, timestamp_gen, string_gen, DecimalGen(precision=18, scale=3),
        DecimalGen(precision=38, scale=4),
        StructGen(children=[
            ['child_int', IntegerGen()],
            ['child_time', DateGen()],
            ['child_string', StringGen()]
        ])]

_float_conf = {'spark.rapids.sql.variableFloatAgg.enabled': 'true',
                       'spark.rapids.sql.castStringToFloat.enabled': 'true'
                      }

@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [SetValuesGen(t, [math.nan, None]) for t in [FloatType(), DoubleType()]], ids=idfn)
def test_float_window_min_max_all_nans(data_gen):
    """
    Check window MIN/MAX over a float/double column whose values are drawn from NaN and null only.

    Column 'a' (bytes) is the partition key; column 'b' is produced by `data_gen`.
    """
    w = Window().partitionBy('a')
    # Aggregate 'b', the NaN/null column. Aggregating the byte partition key 'a' would never
    # send a NaN through MIN/MAX, which is the case this test is meant to cover.
    assert_gpu_and_cpu_are_equal_collect(
        lambda spark: two_col_df(spark, byte_gen, data_gen)
            .withColumn("min_b", f.min('b').over(w))
            .withColumn("max_b", f.max('b').over(w))
    )


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn)
def test_decimal128_count_window(data_gen):
    """COUNT of column 'c' over a bounded ROWS frame, partitioned by 'a' and ordered by 'b'."""
    query = (
        'select '
        ' count(c) over '
        '   (partition by a order by b asc '
        '      rows between 2 preceding and 10 following) as count_c_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: byte partition key, b: unique ordering key, c: value under test
        return three_col_df(spark, byte_gen, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', [decimal_gen_128bit], ids=idfn)
def test_decimal128_count_window_no_part(data_gen):
    """COUNT of column 'b' over a bounded ROWS frame ordered by 'a', with no PARTITION BY."""
    query = (
        'select '
        ' count(b) over '
        '   (order by a asc '
        '      rows between 2 preceding and 10 following) as count_b_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: unique ordering key, b: value under test
        return two_col_df(spark, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
def test_decimal_sum_window(data_gen):
    """SUM of decimal column 'c' over a bounded ROWS frame, partitioned by 'a' and ordered by 'b'."""
    query = (
        'select '
        ' sum(c) over '
        '   (partition by a order by b asc '
        '      rows between 2 preceding and 10 following) as sum_c_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: byte partition key, b: unique ordering key, c: decimal under test
        return three_col_df(spark, byte_gen, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
def test_decimal_sum_window_no_part(data_gen):
    """SUM of decimal column 'b' over a bounded ROWS frame ordered by 'a', with no PARTITION BY."""
    query = (
        'select '
        ' sum(b) over '
        '   (order by a asc '
        '      rows between 2 preceding and 10 following) as sum_b_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: unique ordering key, b: decimal under test
        return two_col_df(spark, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
def test_decimal_running_sum_window(data_gen):
    """Running SUM (UNBOUNDED PRECEDING to CURRENT ROW) of decimal 'c', partitioned by 'a'."""
    # The batch size is set to a very small value (100 bytes) for this test.
    small_batch_conf = {'spark.rapids.sql.batchSizeBytes': '100'}
    query = (
        'select '
        ' sum(c) over '
        '   (partition by a order by b asc '
        '      rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_c_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: byte partition key, b: unique ordering key, c: decimal under test
        return three_col_df(spark, byte_gen, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query, conf=small_batch_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
@pytest.mark.parametrize('data_gen', decimal_gens, ids=idfn)
def test_decimal_running_sum_window_no_part(data_gen):
    """Running SUM (UNBOUNDED PRECEDING to CURRENT ROW) of decimal 'b', with no PARTITION BY."""
    # The batch size is set to a very small value (100 bytes) for this test.
    small_batch_conf = {'spark.rapids.sql.batchSizeBytes': '100'}
    query = (
        'select '
        ' sum(b) over '
        '   (order by a asc '
        '      rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_b_asc '
        'from window_agg_table'
    )

    def make_df(spark):
        # a: unique ordering key, b: decimal under test
        return two_col_df(spark, UniqueLongGen(), data_gen)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query, conf=small_batch_conf)

@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over order by byte column overflow "
                          "(https://github.com/NVIDIA/spark-rapids/pull/2020#issuecomment-838127070)")
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_byte_with_nulls_with_overflow], ids=idfn)
def test_window_aggs_for_ranges_numeric_byte_overflow(data_gen):
    """
    Expected-failure test: SUM over a RANGE frame of +/-127 around an unrestricted byte
    order-by column.

    The select list must not end with a comma before FROM. With a trailing comma the
    statement itself is rejected, so the xfail would be satisfied by a SQL error rather
    than by the range-window behaviour named in the xfail reason.
    """
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'select '
        ' sum(b) over '
        '   (partition by a order by b asc  '
        '      range between 127 preceding and 127 following) as sum_c_asc '
        'from window_agg_table',
        conf={'spark.rapids.sql.window.range.byte.enabled': True})


@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over order by short column overflow "
                          "(https://github.com/NVIDIA/spark-rapids/pull/2020#issuecomment-838127070)")
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_short_with_nulls_with_overflow], ids=idfn)
def test_window_aggs_for_ranges_numeric_short_overflow(data_gen):
    """
    Expected-failure test: SUM over a RANGE frame of +/-32767 around an unrestricted short
    order-by column.

    The select list must not end with a comma before FROM. With a trailing comma the
    statement itself is rejected, so the xfail would be satisfied by a SQL error rather
    than by the range-window behaviour named in the xfail reason.
    """
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'select '
        ' sum(b) over '
        '   (partition by a order by b asc  '
        '      range between 32767 preceding and 32767 following) as sum_c_asc '
        'from window_agg_table',
        conf={'spark.rapids.sql.window.range.short.enabled': True})


@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over order by int column overflow "
                          "(https://github.com/NVIDIA/spark-rapids/pull/2020#issuecomment-838127070)")
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_int_with_nulls_with_overflow], ids=idfn)
def test_window_aggs_for_ranges_numeric_int_overflow(data_gen):
    """
    Expected-failure test: SUM over a RANGE frame of +/-2147483647 around an unrestricted int
    order-by column.

    The select list must not end with a comma before FROM. With a trailing comma the
    statement itself is rejected, so the xfail would be satisfied by a SQL error rather
    than by the range-window behaviour named in the xfail reason.
    """
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'select '
        ' sum(b) over '
        '   (partition by a order by b asc  '
        '      range between 2147483647 preceding and 2147483647 following) as sum_c_asc '
        'from window_agg_table')


@pytest.mark.xfail(reason="[UNSUPPORTED] Ranges over order by long column overflow "
                          "(https://github.com/NVIDIA/spark-rapids/pull/2020#issuecomment-838127070)")
@ignore_order
@pytest.mark.parametrize('data_gen', [_grpkey_long_with_nulls_with_overflow], ids=idfn)
def test_window_aggs_for_ranges_numeric_long_overflow(data_gen):
    """
    Expected-failure test: SUM over a RANGE frame of +/-9223372036854775807 around an
    unrestricted long order-by column.

    The select list must not end with a comma before FROM. With a trailing comma the
    statement itself is rejected, so the xfail would be satisfied by a SQL error rather
    than by the range-window behaviour named in the xfail reason.
    """
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        "window_agg_table",
        'select '
        ' sum(b) over '
        '   (partition by a order by b asc  '
        '      range between 9223372036854775807 preceding and 9223372036854775807 following) as sum_c_asc '
        'from window_agg_table')


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Partition order may differ in a distributed setup, so result order is ignored. The final
# ordering is done locally because small batch sizes can make a sort very slow.
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('data_gen', [
                                      _grpkey_byte_with_nulls,
                                      _grpkey_short_with_nulls,
                                      _grpkey_int_with_nulls,
                                      _grpkey_long_with_nulls,
                                      _grpkey_date_with_nulls,
                                    ], ids=idfn)
def test_window_aggs_for_range_numeric_date(data_gen, batch_size):
    """
    Run SUM/AVG/MAX/MIN/COUNT of column 'c' over RANGE frames, partitioned by 'a' and ordered
    by the integral or date column 'b', using bounded, half-unbounded and fully unbounded frames.
    """
    range_conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
    # Range frames over byte and short order-by columns are switched on explicitly.
    range_conf['spark.rapids.sql.window.range.byte.enabled'] = True
    range_conf['spark.rapids.sql.window.range.short.enabled'] = True

    query = (
        'select '
        ' sum(c) over '
        '   (partition by a order by b asc  '
        '      range between 1 preceding and 3 following) as sum_c_asc, '
        ' avg(c) over '
        '   (partition by a order by b asc  '
        '       range between 1 preceding and 3 following) as avg_b_asc, '
        ' max(c) over '
        '   (partition by a order by b asc '
        '       range between 1 preceding and 3 following) as max_b_desc, '
        ' min(c) over '
        '   (partition by a order by b asc  '
        '       range between 1 preceding and 3 following) as min_b_asc, '
        ' count(1) over '
        '   (partition by a order by b asc  '
        '       range between  CURRENT ROW and UNBOUNDED following) as count_1_asc, '
        ' count(c) over '
        '   (partition by a order by b asc  '
        '       range between  CURRENT ROW and UNBOUNDED following) as count_b_asc, '
        ' avg(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and CURRENT ROW) as avg_b_unbounded, '
        ' sum(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and CURRENT ROW) as sum_b_unbounded, '
        ' max(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and UNBOUNDED following) as max_b_unbounded '
        'from window_agg_table '
    )

    def make_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query, conf=range_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Partition order may differ in a distributed setup, so result order is ignored. The final
# ordering is done locally because small batch sizes can make a sort very slow.
@ignore_order(local=True)
@datagen_overrides(seed=0, reason="https://github.com/NVIDIA/spark-rapids/issues/9682")
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
                                      _grpkey_longs_with_nulls,
                                      _grpkey_longs_with_dates,
                                      _grpkey_longs_with_nullable_dates,
                                      _grpkey_longs_with_decimals,
                                      _grpkey_longs_with_nullable_decimals,
                                      _grpkey_longs_with_nullable_larger_decimals,
                                      _grpkey_decimals_with_nulls], ids=idfn)
def test_window_aggs_for_rows(data_gen, batch_size):
    """
    Run aggregations (SUM/MAX/MIN/COUNT/AVG) and ranking functions (RANK, DENSE_RANK,
    PERCENT_RANK, ROW_NUMBER) over ROWS frames, partitioned by 'a' and ordered by 'b' and 'c'.
    """
    rows_conf = dict([
        ('spark.rapids.sql.batchSizeBytes', batch_size),
        ('spark.rapids.sql.castFloatToDecimal.enabled', True),
    ])

    query = (
        'select '
        ' sum(c) over '
        '   (partition by a order by b,c asc rows between 1 preceding and 1 following) as sum_c_asc, '
        ' max(c) over '
        '   (partition by a order by b desc, c desc rows between 2 preceding and 1 following) as max_c_desc, '
        ' min(c) over '
        '   (partition by a order by b,c rows between 2 preceding and current row) as min_c_asc, '
        ' count(1) over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_1, '
        ' count(c) over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as count_c, '
        ' avg(c) over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and UNBOUNDED following) as avg_c, '
        ' rank() over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as rank_val, '
        ' dense_rank() over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as dense_rank_val, '
        ' percent_rank() over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as percent_rank_val, '
        ' row_number() over '
        '   (partition by a order by b,c rows between UNBOUNDED preceding and CURRENT ROW) as row_num '
        'from window_agg_table '
    )

    def make_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(make_df, "window_agg_table", query, conf=rows_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [
    [('grp', RepeatSeqGen(int_gen, length=20)),  # Grouping column.
     ('ord', UniqueLongGen(nullable=True)),      # Order-by column (after cast to STRING).
     ('agg', IntegerGen())]                      # Aggregation column.
], ids=idfn)
def test_range_windows_with_string_order_by_column(data_gen, batch_size):
    """
    Verify window functions whose ORDER BY expression is a STRING (`ord` cast to STRING).

    Covers ranking functions and aggregations with the default frame, plus explicit RANGE
    frames built from UNBOUNDED PRECEDING, CURRENT ROW and UNBOUNDED FOLLOWING.
    """
    batch_conf = {'spark.rapids.sql.batchSizeBytes': batch_size}

    query = (
        'SELECT '
        ' ROW_NUMBER() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC ) as row_num_asc, '
        ' RANK() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC ) as rank_desc, '
        ' DENSE_RANK() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC ) as dense_rank_asc, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC ) as count_1_asc_default, '
        ' COUNT(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC ) as count_desc_default, '
        ' SUM(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC  '
        '       RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_asc_UNB_to_CURRENT, '
        ' MIN(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC  '
        '       RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as min_desc_UNB_to_CURRENT, '
        ' MAX(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC  '
        '       RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as max_asc_CURRENT_to_UNB, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC  '
        '       RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as count_1_desc_CURRENT_to_UNB, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC  '
        '       RANGE BETWEEN CURRENT ROW AND CURRENT ROW) as count_1_asc_CURRENT_to_CURRENT, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC  '
        '       RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as count_1_asc_UNB_to_UNB, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC  '
        '       RANGE BETWEEN CURRENT ROW AND CURRENT ROW) as count_1_desc_CURRENT_to_CURRENT '
        ' FROM window_agg_table '
    )

    def make_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(make_df, 'window_agg_table', query, conf=batch_conf)

# Covers aggregations that work with the optimized unbounded-to-unbounded window path.
# They need no special batching, but the optimization applies only if every aggregation in
# the query supports it. The returned order should be consistent because the data ends up
# in a single task (no partitioning).
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('b_gen', all_basic_gens + [decimal_gen_32bit, decimal_gen_128bit], ids=meta_idfn('data:'))
def test_window_batched_unbounded_no_part(b_gen, batch_size):
    """MIN/MAX of 'b' over a fully unbounded ROWS frame without PARTITION BY; expects GpuCachedDoublePassWindowExec."""
    window_exprs = (
        'min(b) over (order by a rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as min_col',
        'max(b) over (order by a rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as max_col',
    )
    query = ''.join(['select ', ', '.join(window_exprs), ' from window_agg_table '])

    test_conf = {}
    test_conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    test_conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    def make_df(spark):
        return two_col_df(spark, UniqueLongGen(), b_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_df,
        "window_agg_table",
        query,
        validate_execs_in_gpu_plan=['GpuCachedDoublePassWindowExec'],
        conf=test_conf)

@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('b_gen', all_basic_gens + [decimal_gen_32bit, decimal_gen_128bit], ids=meta_idfn('data:'))
def test_window_batched_unbounded(b_gen, batch_size):
    """
    MIN over an unpartitioned, fully unbounded ROWS frame combined with MAX over the same
    frame partitioned by `a % 2`; expects GpuCachedDoublePassWindowExec in the GPU plan.
    """
    window_exprs = (
        'min(b) over (order by a rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as min_col',
        'max(b) over (partition by a % 2 order by a rows between UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) as max_col',
    )
    query = ''.join(['select ', ', '.join(window_exprs), ' from window_agg_table '])

    test_conf = {}
    test_conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    test_conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    def make_df(spark):
        return two_col_df(spark, UniqueLongGen(), b_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_df,
        "window_agg_table",
        query,
        validate_execs_in_gpu_plan=['GpuCachedDoublePassWindowExec'],
        conf=test_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Covers aggregations that work with the running window optimization. They need no special
# batching, but the optimization applies only if every aggregation in the query supports it.
# The returned order should be consistent because the data ends up in a single task (no partitioning).
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('b_gen', all_basic_gens + [decimal_gen_32bit, decimal_gen_128bit], ids=meta_idfn('data:'))
def test_rows_based_running_window_unpartitioned(b_gen, batch_size):
    """
    Running-window (ROWS UNBOUNDED PRECEDING to CURRENT ROW) functions without PARTITION BY.

    Expects GpuRunningWindowExec in the GPU plan and equal CPU/GPU results.
    """
    test_conf = {}
    test_conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    test_conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    window_exprs = [
        'row_number() over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as row_num',
        'rank() over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as rank_val',
        'dense_rank() over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as dense_rank_val',
        'count(b) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as count_col',
        'min(b) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as min_col',
        'max(b) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as max_col',
        'FIRST(b) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_keep_nulls',
        'FIRST(b, TRUE) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_ignore_nulls',
        'NTH_VALUE(b, 1) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_keep_nulls',
        'LAST(b) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_keep_nulls',
        'LAST(b, TRUE) OVER (ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_ignore_nulls',
    ]

    # SUM is added only for numeric types other than float and double.
    is_numeric = isinstance(b_gen.data_type, NumericType)
    is_floating_point = isinstance(b_gen, (FloatGen, DoubleGen))
    if is_numeric and not is_floating_point:
        window_exprs.append('sum(b) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_col')

    # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.2.1.
    if spark_version() >= "3.2.1":
        window_exprs.append('NTH_VALUE(b, 1) IGNORE NULLS OVER '
                            '(ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_ignore_nulls')

    query = ''.join(['select ', ', '.join(window_exprs), ' from window_agg_table '])

    def make_df(spark):
        return two_col_df(spark, UniqueLongGen(), b_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_df,
        "window_agg_table",
        query,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf=test_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)  # Testing multiple batch sizes.
@pytest.mark.parametrize('a_gen', integral_gens + [string_gen, date_gen, timestamp_gen], ids=meta_idfn('data:'))
@allow_non_gpu(*non_utc_allow)
def test_running_window_without_partitions_runs_batched(a_gen, batch_size):
    """
    This tests the running window optimization as applied to RANGE-based window specifications,
    so long as the bounds are defined as [UNBOUNDED PRECEDING, CURRENT ROW].
    This test verifies the following:
      1. All tested aggregations invoke `GpuRunningWindowExec`, indicating that the running window
         optimization is in effect.
      2. The execution is batched, i.e. does not require that the entire input is loaded at once.
      3. The CPU and GPU runs produce the same results, regardless of batch size.

    Note that none of the ranking functions (including ROW_NUMBER) can be tested as a RANGE query.
    By definition, ranking functions require ROW frames.

    Note, also, that the order-by column is not generated via `UniqueLongGen()`.  This is specifically
    to test the case where `CURRENT ROW` might include more than a single row (as is possible in
    RANGE queries).  To mitigate the occurrence of non-deterministic results, the order-by column
    is also used in the aggregation.  This way, regardless of order, the same value is aggregated.
    """
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}
    query_parts = [
        'COUNT(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS count_col',
        'MIN(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS min_col',
        'MAX(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS max_col',
        'FIRST(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_keep_nulls',
        'FIRST(a, TRUE) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS first_ignore_nulls',
        'NTH_VALUE(a, 1) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_keep_nulls',
        'LAST(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_keep_nulls',
        'LAST(a, TRUE) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS last_ignore_nulls',
    ]

    def must_test_sum_aggregation(gen):
        """Return True if SUM(a) should be added to the query for the given data generator."""
        # Inspect the generator's Spark data type, not the generator object itself.
        if isinstance(gen.data_type, (DateType, TimestampType)):
            return False  # These types do not support SUM().
        # For Float/Double types, skip `SUM()` test. This is tested in test_running_float_sum_no_part.
        return isinstance(gen.data_type, NumericType) and \
            not isinstance(gen, (FloatGen, DoubleGen))

    if must_test_sum_aggregation(a_gen):
        query_parts.append('SUM(a) OVER (ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS sum_col')

    # The option to IGNORE NULLS in NTH_VALUE is only exercised on Spark 3.2.1 and later.
    if spark_version() >= "3.2.1":
        query_parts.append('NTH_VALUE(a, 1) IGNORE NULLS OVER '
                           '(ORDER BY a RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS nth_1_ignore_nulls')

    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, StructGen([('a', a_gen)], nullable=False), length=1024*14),
        "window_agg_table",
        'select ' +
        ', '.join(query_parts) +
        ' from window_agg_table ',
        validate_execs_in_gpu_plan = ['GpuRunningWindowExec'],
        conf = conf)


# Running window SUM on floats and doubles. This is problematic because the aggregation is
# done in parallel, so the result can flip between Inf and not-Inf depending on the order in
# which values are aggregated. The test limits the range of the summed values so they never
# reach Inf, and uses abs() so positive and negative values do not interfere with each other.
# The returned order should be consistent because the data ends up in a single task (no partitioning).
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_running_float_sum_no_part(batch_size):
    """Running ROWS-frame SUM over float/double expressions, ordered by unique column 'a'."""
    test_conf = {}
    test_conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    test_conf['spark.rapids.sql.variableFloatAgg.enabled'] = True
    test_conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    select_exprs = (
        'a',
        'sum(cast(b as double)) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as shrt_dbl_sum',
        'sum(abs(dbl)) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as dbl_sum',
        'sum(cast(b as float)) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as shrt_flt_sum',
        'sum(abs(flt)) over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as flt_sum',
    )
    query = ''.join(['select ', ', '.join(select_exprs), ' from window_agg_table '])

    columns = [('a', UniqueLongGen()), ('b', short_gen), ('flt', float_gen), ('dbl', double_gen)]
    row_gen = StructGen(columns, nullable=False)

    def make_df(spark):
        return gen_df(spark, row_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_df,
        "window_agg_table",
        query,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf=test_conf)


@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)  # Tests different batch sizes.
def test_running_window_float_sum_without_partitions_runs_batched(batch_size):
    """
    RANGE-frame counterpart of test_running_float_sum_no_part: checks that RANGE window SUM
    aggregations on floats and doubles can run in batched mode.

    The order-by column 'b' may contain repeated values, so `CURRENT ROW` can refer to several
    rows. Because the ordering among repeated values is not deterministic, this could make the
    results non-deterministic. That is mitigated by aggregating on the same column as the
    order-by column, so that the same value is aggregated for the repeated keys.
    """
    test_conf = {}
    test_conf['spark.rapids.sql.batchSizeBytes'] = batch_size
    test_conf['spark.rapids.sql.variableFloatAgg.enabled'] = True
    test_conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    select_exprs = (
        'b',
        'SUM(CAST(b AS DOUBLE)) OVER (ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS shrt_dbl_sum',
        'SUM(ABS(dbl)) OVER (ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS dbl_sum',
        'SUM(CAST(b AS FLOAT)) OVER (ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS shrt_flt_sum',
        'SUM(ABS(flt)) OVER (ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS flt_sum',
    )
    query = ''.join(['SELECT ', ', '.join(select_exprs), ' FROM window_agg_table '])

    columns = [('b', short_gen), ('flt', float_gen), ('dbl', double_gen)]
    row_gen = StructGen(columns, nullable=False)

    def make_df(spark):
        return gen_df(spark, row_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_df,
        "window_agg_table",
        query,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf=test_conf)


# Rank aggregations are running window aggregations, but they care about the ordering. Most tests
# do not allow duplicates in the ordering because that makes the results ambiguous: if two rows
# with the same order-by value get switched, something like a running sum can produce different
# results. Here duplicates in the ordering are allowed, because there are no other columns. If
# rows are switched it does not matter, because the only rows that can be switched are exactly
# the same.
@pytest.mark.parametrize('data_gen',
                         all_basic_gens + [decimal_gen_32bit, orderable_decimal_gen_128bit],
                         ids=meta_idfn('data:'))
@allow_non_gpu(*non_utc_allow)
def test_window_running_rank_no_part(data_gen):
    """Compare CPU and GPU RANK/DENSE_RANK over an un-partitioned running window."""
    running_frame = 'over (order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW)'
    select_list = ['a'] + ['{0}() {1} as {0}_val'.format(rank_fn, running_frame)
                           for rank_fn in ('rank', 'dense_rank')]
    sql_text = 'select ' + ', '.join(select_list) + ' from window_agg_table '

    def make_input(spark):
        # When generating the ordering try really hard to have duplicate values
        return unary_op_df(spark, RepeatSeqGen(data_gen, length=500), length=1024 * 14)

    # Keep the batch size small. These have already been tested with operators with exact inputs,
    # so this is mostly testing the fixup operation.
    assert_gpu_and_cpu_are_equal_sql(
        make_input,
        "window_agg_table",
        sql_text,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': 1000})

# Rank aggregations are running window aggregations, but they care about the ordering. Most tests
# do not allow duplicates in the ordering because that makes the results ambiguous: if two rows
# with the same order-by value get switched, something like a running sum can produce different
# results. Here duplicates in the ordering are allowed, because there are no other columns. If
# rows are switched it does not matter, because the only rows that can be switched are exactly
# the same.
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', all_basic_gens + [decimal_gen_32bit], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_window_running_rank(data_gen):
    """Compare CPU and GPU RANK/DENSE_RANK over a partitioned running window."""
    running_frame = 'over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW)'
    select_list = ['b', 'a'] + ['{0}() {1} as {0}_val'.format(rank_fn, running_frame)
                                for rank_fn in ('rank', 'dense_rank')]
    sql_text = 'select ' + ', '.join(select_list) + ' from window_agg_table '

    def make_input(spark):
        # When generating the ordering try really hard to have duplicate values
        return two_col_df(spark,
                          RepeatSeqGen(data_gen, length=500),
                          RepeatSeqGen(data_gen, length=100),
                          length=1024 * 14)

    # Keep the batch size small. These have already been tested with operators with exact inputs,
    # so this is mostly testing the fixup operation.
    assert_gpu_and_cpu_are_equal_sql(
        make_input,
        "window_agg_table",
        sql_text,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': 1000})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Covers aggregations that work with the running window optimization. They do not need to be
# batched specially, but the optimization only applies if all of the aggregations support it.
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('b_gen, c_gen', [(long_gen, x) for x in running_part_and_order_gens] +
        [(x, long_gen) for x in all_basic_gens + [decimal_gen_32bit]], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_rows_based_running_window_partitioned(b_gen, c_gen, batch_size):
    """
    Compare CPU and GPU results for ROWS-based running window aggregations that are
    partitioned by `b`, ordered by the unique column `a`, and aggregate column `c`.
    """
    # The query mixes lower-case and upper-case SQL; both spellings describe the same frame.
    lower_frame = 'over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW)'
    upper_frame = 'OVER (PARTITION BY b ORDER BY a ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)'

    def lower_case(agg, alias):
        return '{} {} as {}'.format(agg, lower_frame, alias)

    def upper_case(agg, alias):
        return '{} {} AS {}'.format(agg, upper_frame, alias)

    select_list = ['b', 'a']
    select_list.extend(lower_case(agg, alias) for agg, alias in [
        ('row_number()', 'row_num'),
        ('rank()', 'rank_val'),
        ('dense_rank()', 'dense_rank_val'),
        ('count(c)', 'count_col'),
        ('min(c)', 'min_col'),
        ('max(c)', 'max_col'),
    ])
    select_list.extend(upper_case(agg, alias) for agg, alias in [
        ('FIRST(c)', 'first_keep_nulls'),
        ('FIRST(c, TRUE)', 'first_ignore_nulls'),
        ('NTH_VALUE(c, 1)', 'nth_1_keep_nulls'),
        ('LAST(c)', 'last_keep_nulls'),
        ('LAST(c, TRUE)', 'last_ignore_nulls'),
    ])

    # Decimal precision can grow too large. Float and Double can get odd results for Inf/-Inf because of ordering
    sum_is_safe = isinstance(c_gen.data_type, NumericType) \
        and not isinstance(c_gen, (FloatGen, DoubleGen, DecimalGen))
    if sum_is_safe:
        select_list.append(lower_case('sum(c)', 'sum_col'))

    # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.2.1.
    if spark_version() >= "3.2.1":
        select_list.append(upper_case('NTH_VALUE(c, 1) IGNORE NULLS', 'nth_1_ignore_nulls'))

    sql_text = 'select ' + ', '.join(select_list) + ' from window_agg_table '

    def make_input(spark):
        return three_col_df(spark,
                            UniqueLongGen(),
                            RepeatSeqGen(b_gen, length=100),
                            c_gen,
                            length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_input,
        "window_agg_table",
        sql_text,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': batch_size,
              'spark.rapids.sql.variableFloatAgg.enabled': True,
              'spark.rapids.sql.castFloatToDecimal.enabled': True})



@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)  # Test different batch sizes.
@pytest.mark.parametrize('part_gen', [int_gen, long_gen], ids=idfn)  # Partitioning is not really the focus of the test.
@pytest.mark.parametrize('order_gen', [x for x in all_basic_gens_no_null if x not in boolean_gens] + [decimal_gen_32bit], ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_range_running_window_runs_batched(part_gen, order_gen, batch_size):
    """
    This tests the running window optimization as applied to RANGE-based window specifications,
    so long as the bounds are defined as [UNBOUNDED PRECEDING, CURRENT ROW].
    This test verifies the following:
      1. All tested aggregations invoke `GpuRunningWindowExec`, indicating that the running window
         optimization is in effect.
      2. The execution is batched, i.e. does not require that the entire input is loaded at once.
      3. The CPU and GPU runs produce the same results, regardless of batch size.

    Note that none of the ranking functions (including ROW_NUMBER) can be tested as a RANGE query.
    By definition, ranking functions require ROW frames.

    Note, also, that the order-by column is not generated via `UniqueLongGen()`.  This is specifically
    to test the case where `CURRENT ROW` might include more than a single row (as is possible in
    RANGE queries).  To mitigate the occurrence of non-deterministic results, the order-by column
    is also used in the aggregation.  This way, regardless of order, the same value is aggregated.

    :param part_gen: data generator for the partition-by column `p`.
    :param order_gen: data generator for the order-by (and aggregated) column `oby`.
    :param batch_size: value used for `spark.rapids.sql.batchSizeBytes`.
    """
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.variableFloatAgg.enabled': True,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}
    window = "(PARTITION BY p ORDER BY oby RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) "
    query_parts = [
        'p', 'oby',
        'COUNT(oby) OVER        ' + window + ' AS count_col',
        'MIN(oby) OVER          ' + window + ' AS min_col',
        'MAX(oby) OVER          ' + window + ' AS max_col',
        'FIRST(oby) OVER        ' + window + ' AS first_keep_nulls',
        'FIRST(oby, TRUE) OVER  ' + window + ' AS first_ignore_nulls',
        'NTH_VALUE(oby, 1) OVER ' + window + ' AS nth_1_keep_nulls',
        'LAST(oby) OVER         ' + window + ' AS last_keep_nulls',
        'LAST(oby, TRUE) OVER   ' + window + ' AS last_ignore_nulls',
    ]

    def must_test_sum_aggregation(gen):
        """Return True if SUM(oby) should be added to the query for the given order-by generator."""
        # `gen` is a data generator, not a Spark type, so the Spark type has to be read from
        # `gen.data_type`. Checking `gen` itself against DateType could never match.
        if isinstance(gen.data_type, (DateType, TimestampType)):
            return False  # These types do not support SUM().
        # For Float/Double types, skip `SUM()` test. This is tested later.
        # Decimal precision can grow too large. Float and Double can get odd results for Inf/-Inf because of ordering
        return isinstance(gen.data_type, NumericType) and \
            not isinstance(gen, (FloatGen, DoubleGen, DecimalGen))

    if must_test_sum_aggregation(order_gen):
        query_parts.append('SUM(oby) OVER ' + window + ' AS sum_col')

    # The option to IGNORE NULLS in NTH_VALUE is not available prior to Spark 3.2.1.
    if spark_version() >= "3.2.1":
        query_parts.append('NTH_VALUE(oby, 1) IGNORE NULLS OVER ' + window + ' AS nth_1_ignore_nulls')

    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark,
                             StructGen([('p', RepeatSeqGen(part_gen, length=100)),
                                        ('oby', order_gen)], nullable=False),
                             length=1024*14),
        "window_agg_table",
        'SELECT ' +
        ', '.join(query_parts) +
        ' FROM window_agg_table ',
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf=conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Test that we can do a running window sum on floats and doubles and decimal. This becomes problematic because we do the agg in parallel
# which means that the result can switch back and forth from Inf to not Inf depending on the order of aggregations.
# We test this by limiting the range of the values in the sum to never hit Inf, and by using abs so we don't have
# positive and negative values that interfere with each other.
# decimal is problematic if the precision is so high it falls back to the CPU.
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
# Float sums are compared approximately, matching the other running float-sum tests in this file.
@approximate_float
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
def test_window_running_float_decimal_sum(batch_size):
    """
    Compare CPU and GPU results for partitioned ROWS-based running SUMs over double, float and
    DECIMAL(6,1) values.

    :param batch_size: value used for `spark.rapids.sql.batchSizeBytes`.
    """
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size,
            'spark.rapids.sql.variableFloatAgg.enabled': True,
            'spark.rapids.sql.castFloatToDecimal.enabled': True}
    # Every output column gets a distinct alias: the sums over the short column `c` are prefixed
    # with `shrt_` so they do not collide with the sums over the `dbl` and `flt` columns.
    query_parts = ['b', 'a',
            'sum(cast(c as double)) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as shrt_dbl_sum',
            'sum(abs(dbl)) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as dbl_sum',
            'sum(cast(c as float)) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as shrt_flt_sum',
            'sum(abs(flt)) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as flt_sum',
            'sum(cast(c as Decimal(6,1))) over (partition by b order by a rows between UNBOUNDED PRECEDING AND CURRENT ROW) as dec_sum']

    gen = StructGen([('a', UniqueLongGen()),('b', RepeatSeqGen(int_gen, length=1000)),('c', short_gen),('flt', float_gen),('dbl', double_gen)], nullable=False)
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark : gen_df(spark, gen, length=1024 * 14),
        "window_agg_table",
        'select ' +
        ', '.join(query_parts) +
        ' from window_agg_table ',
        validate_execs_in_gpu_plan = ['GpuRunningWindowExec'],
        conf = conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@approximate_float
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)  # Test different batch sizes.
def test_range_running_window_float_decimal_sum_runs_batched(batch_size):
    """
    RANGE-frame counterpart of test_window_running_float_decimal_sum: checks that partitioned
    RANGE window SUM aggregations can run in batched mode.

    In the RANGE case the order-by column may contain repeated values, so `CURRENT ROW` might
    refer to multiple rows. Because the ordering among repeated values is not deterministic,
    this could produce non-deterministic results. That is mitigated by ordering each window by
    the very expression that is being summed, so that the same value is aggregated for the
    repeated keys.
    """
    # `{order_by}` is replaced with the expression being summed. The trailing space is part of
    # the generated SQL text.
    frame_template = "(PARTITION BY p ORDER BY {order_by} RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) "

    # (text before the frame, order-by expression, text after the frame)
    sum_specs = [
        ('SUM(CAST(oby AS DOUBLE)) OVER       ', 'CAST(oby AS DOUBLE)', '       AS short_double_sum'),
        ('SUM(ABS(dbl)) OVER                  ', 'ABS(dbl)', '                  AS double_sum'),
        ('SUM(CAST(oby AS FLOAT)) OVER        ', 'CAST(oby AS FLOAT)', '        AS short_float_sum'),
        ('SUM(ABS(flt)) OVER                  ', 'ABS(flt)', '                  AS float_sum'),
        ('SUM(CAST(oby AS DECIMAL(6,1))) OVER ', 'CAST(oby AS DECIMAL(6,1))', ' AS dec_sum'),
    ]
    select_list = ['p', 'oby']
    for head, order_by_expr, tail in sum_specs:
        select_list.append(head + frame_template.format(order_by=order_by_expr) + tail)
    sql_text = 'SELECT ' + ', '.join(select_list) + ' FROM window_agg_table '

    row_gen = StructGen(
        [('p', RepeatSeqGen(int_gen, length=1000)),
         ('oby', short_gen),
         ('flt', float_gen),
         ('dbl', double_gen)],
        nullable=False)

    def make_input(spark):
        return gen_df(spark, row_gen, length=1024 * 14)

    assert_gpu_and_cpu_are_equal_sql(
        make_input,
        "window_agg_table",
        sql_text,
        validate_execs_in_gpu_plan=['GpuRunningWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': batch_size,
              'spark.rapids.sql.variableFloatAgg.enabled': True,
              'spark.rapids.sql.castFloatToDecimal.enabled': True})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn) # set the batch size so we can test multiple stream batches
@pytest.mark.parametrize('c_gen', lead_lag_data_gens, ids=idfn)
@pytest.mark.parametrize('a_b_gen', part_and_order_gens, ids=meta_idfn('partAndOrderBy:'))
@allow_non_gpu(*non_utc_allow)
def test_multi_types_window_aggs_for_rows_lead_lag(a_b_gen, c_gen, batch_size):
    """
    Compare CPU and GPU results for COUNT/MIN/MAX over a ROWS frame plus LEAD, LAG and
    ROW_NUMBER, for each combination of partition/order type and aggregated column type.
    """
    columns_gen = [
            ('a', RepeatSeqGen(a_b_gen, length=20)),
            ('b', a_b_gen),
            ('c', c_gen)]
    # By default for many operations a range of unbounded to unbounded is used
    # This will not work until https://github.com/NVIDIA/spark-rapids/issues/216
    # is fixed.

    # Ordering needs to include c because with nulls and especially on booleans
    # it is possible to get a different ordering when it is ambiguous.
    base_window_spec = Window.partitionBy('a').orderBy('b', 'c')
    inclusive_window_spec = base_window_spec.rowsBetween(-10, 100)

    def add_columns(df, named_columns):
        # Append the columns one at a time, in the order listed.
        for col_name, col_expr in named_columns:
            df = df.withColumn(col_name, col_expr)
        return df

    def do_it(spark):
        df = add_columns(gen_df(spark, columns_gen, length=2048), [
            ('inc_count_1', f.count('*').over(inclusive_window_spec)),
            ('inc_count_c', f.count('c').over(inclusive_window_spec)),
            ('lead_5_c', f.lead('c', 5).over(base_window_spec)),
            ('lag_1_c', f.lag('c', 1).over(base_window_spec)),
            ('row_num', f.row_number().over(base_window_spec)),
        ])

        if isinstance(c_gen, StructGen):
            # The MIN()/MAX() aggregations amount to a RANGE query. These are not
            # currently supported on STRUCT columns.
            # Also, LEAD()/LAG() defaults cannot currently be specified for STRUCT
            # columns. `[ 10, 3.14159, "foobar" ]` isn't recognized as a valid STRUCT scalar.
            return add_columns(df, [
                ('lead_def_c', f.lead('c', 2, None).over(base_window_spec)),
                ('lag_def_c', f.lag('c', 4, None).over(base_window_spec)),
            ])

        # Non-STRUCT columns also get MIN/MAX, and LEAD/LAG with a generated default value.
        default_val = with_cpu_session(lambda spark: gen_scalar_value(c_gen, force_no_nulls=False))
        return add_columns(df, [
            ('inc_max_c', f.max('c').over(inclusive_window_spec)),
            ('inc_min_c', f.min('c').over(inclusive_window_spec)),
            ('lead_def_c', f.lead('c', 2, default_val).over(base_window_spec)),
            ('lag_def_c', f.lag('c', 4, default_val).over(base_window_spec)),
        ])

    assert_gpu_and_cpu_are_equal_collect(do_it, conf={'spark.rapids.sql.batchSizeBytes': batch_size})


# STRUCT generator whose children mix flat types with an ARRAY child.
struct_with_arrays = StructGen(
    children=[
        ['child_int', int_gen],
        ['child_time', date_gen],
        ['child_string', string_gen],
        ['child_array', ArrayGen(int_gen, max_length=10)],
    ])

# Generators built on `struct_with_arrays`: the struct itself, an array of that struct,
# and the struct nested inside another struct.
lead_lag_struct_with_arrays_gen = [
    struct_with_arrays,
    ArrayGen(struct_with_arrays, max_length=10),
    StructGen(children=[['child_struct', struct_with_arrays]]),
]


@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('struct_gen', lead_lag_struct_with_arrays_gen, ids=idfn)
@pytest.mark.parametrize('a_b_gen', part_and_order_gens, ids=meta_idfn('partAndOrderBy:'))
@allow_non_gpu(*non_utc_allow)
def test_lead_lag_for_structs_with_arrays(a_b_gen, struct_gen):
    """Compare CPU and GPU results for LEAD/LAG on struct columns that contain arrays."""
    columns_gen = [
        ('a', RepeatSeqGen(a_b_gen, length=20)),
        ('b', UniqueLongGen(nullable=False)),
        ('c', struct_gen)]
    # For many operations, a range of unbounded to unbounded is used by default.

    # Ordering needs to include `b` because with nulls and especially on booleans,
    # it is possible to get a different result when the ordering is ambiguous.
    ordered_by_b = Window.partitionBy('a').orderBy('b')

    def do_it(spark):
        df = gen_df(spark, columns_gen, length=2048)
        df = df.withColumn('lead_5_c', f.lead('c', 5).over(ordered_by_b))
        return df.withColumn('lag_1_c', f.lag('c', 1).over(ordered_by_b))

    assert_gpu_and_cpu_are_equal_collect(do_it)


# Array generators for LEAD/LAG tests: every generator in `lead_lag_data_gens` wrapped in
# one, two and three levels of ArrayGen (in that order), each level with max_length=10.
lead_lag_array_data_gens = [
    ArrayGen(sub_gen, max_length=10)
    for sub_gen in lead_lag_data_gens
] + [
    ArrayGen(ArrayGen(sub_gen, max_length=10), max_length=10)
    for sub_gen in lead_lag_data_gens
] + [
    ArrayGen(ArrayGen(ArrayGen(sub_gen, max_length=10), max_length=10), max_length=10)
    for sub_gen in lead_lag_data_gens
]


@ignore_order(local=True)
@pytest.mark.parametrize('d_gen', lead_lag_array_data_gens, ids=meta_idfn('agg:'))
@pytest.mark.parametrize('c_gen', [UniqueLongGen()], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('b_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', [long_gen], ids=meta_idfn('partBy:'))
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_rows_lead_lag_on_arrays(a_gen, b_gen, c_gen, d_gen):
    """
    Compare CPU and GPU results for LEAD/LAG on array columns, both without a default and with
    a default taken from a second column (`d_default`) produced by the same generator.
    """
    columns_gen = [
            ('a', RepeatSeqGen(a_gen, length=20)),
            ('b', b_gen),
            ('c', c_gen),
            ('d', d_gen),
            ('d_default', d_gen)]

    sql_text = '''
        SELECT
            LEAD(d, 5) OVER (PARTITION by a ORDER BY b,c) lead_d_5,
            LEAD(d, 2, d_default) OVER (PARTITION by a ORDER BY b,c) lead_d_2_default,
            LAG(d, 5) OVER (PARTITION by a ORDER BY b,c) lag_d_5,
            LAG(d, 2, d_default) OVER (PARTITION by a ORDER BY b,c) lag_d_2_default
        FROM window_agg_table
        '''

    def make_input(spark):
        return gen_df(spark, columns_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(make_input, "window_agg_table", sql_text)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# lead and lag don't currently work for string columns, so redo the tests, but just for strings
# without lead and lag
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('c_gen', [string_gen], ids=idfn)
@pytest.mark.parametrize('a_b_gen', part_and_order_gens, ids=meta_idfn('partAndOrderBy:'))
@allow_non_gpu(*non_utc_allow)
def test_multi_types_window_aggs_for_rows(a_b_gen, c_gen):
    """
    Compare CPU and GPU results for COUNT/MAX/MIN over a ROWS frame plus the ranking functions,
    with a string aggregation column.
    """
    columns_gen = [
            ('a', RepeatSeqGen(a_b_gen, length=20)),
            ('b', a_b_gen),
            ('c', c_gen)]
    # By default for many operations a range of unbounded to unbounded is used
    # This will not work until https://github.com/NVIDIA/spark-rapids/issues/216
    # is fixed.

    # Ordering needs to include c because with nulls and especially on booleans
    # it is possible to get a different ordering when it is ambiguous
    base_window_spec = Window.partitionBy('a').orderBy('b', 'c')
    inclusive_window_spec = base_window_spec.rowsBetween(-10, 100)

    def do_it(spark):
        df = gen_df(spark, columns_gen, length=2048)
        named_columns = [
            ('inc_count_1', f.count('*').over(inclusive_window_spec)),
            ('inc_count_c', f.count('c').over(inclusive_window_spec)),
            ('inc_max_c', f.max('c').over(inclusive_window_spec)),
            ('inc_min_c', f.min('c').over(inclusive_window_spec)),
            ('rank_val', f.rank().over(base_window_spec)),
            ('dense_rank_val', f.dense_rank().over(base_window_spec)),
            ('percent_rank_val', f.percent_rank().over(base_window_spec)),
            ('row_num', f.row_number().over(base_window_spec)),
        ]
        # Append the columns one at a time, in the order listed.
        for col_name, col_expr in named_columns:
            df = df.withColumn(col_name, col_expr)
        return df

    assert_gpu_and_cpu_are_equal_collect(do_it)


def test_percent_rank_no_part_multiple_batches():
    """Compare CPU and GPU PERCENT_RANK over an un-partitioned window with a tiny batch size."""
    columns_gen = [('a', long_gen)]
    unpartitioned_spec = Window.orderBy('a')

    def do_it(spark):
        # The goal of this is to have multiple batches so we can verify that the code
        # is working properly, but not so large that it takes forever to run.
        df = gen_df(spark, columns_gen, length=8000)
        return df.withColumn('percent_rank_val', f.percent_rank().over(unpartitioned_spec))

    tiny_batch_conf = {'spark.rapids.sql.batchSizeBytes': '100'}
    assert_gpu_and_cpu_are_equal_collect(do_it, conf=tiny_batch_conf)

def test_percent_rank_single_part_multiple_batches():
    """
    Compare CPU and GPU PERCENT_RANK when every row falls into one partition (`b` is the
    literal 1), with a tiny batch size.
    """
    columns_gen = [('a', long_gen)]
    single_partition_spec = Window.partitionBy('b').orderBy('a')

    def do_it(spark):
        # The goal of this is to have multiple batches so we can verify that the code
        # is working properly, but not so large that it takes forever to run.
        df = gen_df(spark, columns_gen, length=8000)
        df = df.withColumn('b', f.lit(1))
        return df.withColumn('percent_rank_val', f.percent_rank().over(single_partition_spec))

    tiny_batch_conf = {'spark.rapids.sql.batchSizeBytes': '100'}
    assert_gpu_and_cpu_are_equal_collect(do_it, conf=tiny_batch_conf)

@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0 is IGNORE NULLS supported for lead and lag by Spark")
@allow_non_gpu('WindowExec', 'Alias', 'WindowExpression', 'Lead', 'Literal', 'WindowSpecDefinition', 'SpecifiedWindowFrame', *non_utc_allow)
@ignore_order(local=True)
@pytest.mark.parametrize('d_gen', all_basic_gens, ids=meta_idfn('agg:'))
@pytest.mark.parametrize('c_gen', [UniqueLongGen()], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('b_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', [long_gen], ids=meta_idfn('partBy:'))
def test_window_aggs_lead_ignore_nulls_fallback(a_gen, b_gen, c_gen, d_gen):
    """Run LEAD ... IGNORE NULLS through the SQL fallback assertion, naming 'Lead' as the fallback class."""
    columns_gen = [
            ('a', RepeatSeqGen(a_gen, length=20)),
            ('b', b_gen),
            ('c', c_gen),
            ('d', d_gen)]

    sql_text = '''
        SELECT
            LEAD(d, 5) IGNORE NULLS OVER (PARTITION by a ORDER BY b,c) lead_d_5
        FROM window_agg_table
        '''

    def make_input(spark):
        return gen_df(spark, columns_gen)

    assert_gpu_sql_fallback_collect(make_input, 'Lead', "window_agg_table", sql_text)

@pytest.mark.skipif(is_before_spark_320(), reason="Only in Spark 3.2.0 is IGNORE NULLS supported for lead and lag by Spark")
@allow_non_gpu('WindowExec', 'Alias', 'WindowExpression', 'Lag', 'Literal', 'WindowSpecDefinition', 'SpecifiedWindowFrame', *non_utc_allow)
@ignore_order(local=True)
@pytest.mark.parametrize('d_gen', all_basic_gens, ids=meta_idfn('agg:'))
@pytest.mark.parametrize('c_gen', [UniqueLongGen()], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('b_gen', [long_gen], ids=meta_idfn('orderBy:'))
@pytest.mark.parametrize('a_gen', [long_gen], ids=meta_idfn('partBy:'))
def test_window_aggs_lag_ignore_nulls_fallback(a_gen, b_gen, c_gen, d_gen):
    """Run LAG ... IGNORE NULLS through the SQL fallback assertion, naming 'Lag' as the fallback class."""
    columns_gen = [
            ('a', RepeatSeqGen(a_gen, length=20)),
            ('b', b_gen),
            ('c', c_gen),
            ('d', d_gen)]

    sql_text = '''
        SELECT
            LAG(d, 5) IGNORE NULLS OVER (PARTITION by a ORDER BY b,c) lag_d_5
        FROM window_agg_table
        '''

    def make_input(spark):
        return gen_df(spark, columns_gen)

    assert_gpu_sql_fallback_collect(make_input, 'Lag', "window_agg_table", sql_text)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Test for RANGE queries, with timestamp order-by expressions.
# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_timestamps,
                                      pytest.param(_grpkey_longs_with_nullable_timestamps)],
                                      ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_ranges_timestamps(data_gen):
    """
    Compare CPU and GPU results for RANGE window aggregations ordered by column `b`, using
    interval-bounded, CURRENT ROW-bounded and UNBOUNDED frames, ascending and descending.
    """
    # The query text is assembled from adjacent string literals, one fragment per line.
    range_query = (
        'select '
        ' sum(c) over '
        '   (partition by a order by b asc  '
        '       range between interval 1 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND preceding '
        '             and interval 1 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND following) as sum_c_asc, '
        ' avg(c) over '
        '   (partition by a order by b asc  '
        '       range between interval 1 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND preceding '
        '            and interval 1 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND following) as avg_c_asc, '
        ' max(c) over '
        '   (partition by a order by b desc '
        '       range between interval 2 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND preceding '
        '            and interval 1 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND following) as max_c_desc, '
        ' min(c) over '
        '   (partition by a order by b asc  '
        '       range between interval 2 DAY 5 HOUR 3 MINUTE 2 SECOND 1 MILLISECOND 5 MICROSECOND preceding '
        '            and current row) as min_c_asc, '
        ' count(1) over '
        '   (partition by a order by b asc  '
        '       range between  CURRENT ROW and UNBOUNDED following) as count_1_asc, '
        ' count(c) over '
        '   (partition by a order by b asc  '
        '       range between  CURRENT ROW and UNBOUNDED following) as count_c_asc, '
        ' avg(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and CURRENT ROW) as avg_c_unbounded, '
        ' sum(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and CURRENT ROW) as sum_c_unbounded, '
        ' max(c) over '
        '   (partition by a order by b asc  '
        '       range between UNBOUNDED preceding and UNBOUNDED following) as max_c_unbounded '
        'from window_agg_table')

    def make_input(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(
        make_input,
        "window_agg_table",
        range_query,
        conf={'spark.rapids.sql.castFloatToDecimal.enabled': True})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Partitions may come back in a different order on a distributed cluster, so row order is ignored.
# The final ordering is done locally because small batch sizes can make a sort very slow.
@ignore_order(local=True)
@pytest.mark.parametrize(
    'data_gen',
    [
        _grpkey_longs_with_nullable_decimals,
        _grpkey_longs_with_nullable_larger_decimals,
        pytest.param(
            _grpkey_longs_with_nullable_largest_decimals,
            marks=pytest.mark.xfail(
                condition=is_databricks113_or_later(),
                reason='https://github.com/NVIDIA/spark-rapids/issues/7429')),
        _grpkey_longs_with_nullable_floats,
        _grpkey_longs_with_nullable_doubles,
    ],
    ids=idfn)
def test_window_aggregations_for_decimal_and_float_ranges(data_gen):
    """
    Range-window aggregations whose ORDER BY column is DECIMAL or floating point.

    Schema of the generated table:
      a: Group By (partition) column
      b: Order By column (decimals, floats, doubles)
      c: Aggregation column (decimals or ints)

    The subject under test is the type of the order-by column, not any single
    aggregation, so most expressions are COUNT(1) over different window widths
    and sort directions. A few other aggregation functions are added for variety.
    """
    def build_table(spark):
        # 2048 rows are generated for each session the helper runs.
        return gen_df(spark, data_gen, length=2048)

    window_query = (
        'SELECT '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b ASC RANGE BETWEEN 10.2345 PRECEDING AND 6.7890 FOLLOWING), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b ASC), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b ASC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b DESC RANGE BETWEEN 10.2345 PRECEDING AND 6.7890 FOLLOWING), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b DESC), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b DESC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),'
        ' COUNT(c) OVER (PARTITION BY a ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' SUM(c)   OVER (PARTITION BY a ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' MIN(c)   OVER (PARTITION BY a ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' MAX(c)   OVER (PARTITION BY a ORDER BY b RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
        ' RANK()   OVER (PARTITION BY a ORDER BY b) '
        'FROM window_agg_table'
    )

    assert_gpu_and_cpu_are_equal_sql(build_table, "window_agg_table", window_query, conf={})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Partitions may come back in a different order on a distributed cluster, so row order is ignored.
# The final ordering is done locally because small batch sizes can make a sort very slow.
@ignore_order(local=True)
@pytest.mark.parametrize(
    'data_gen',
    [
        pytest.param(
            _grpkey_longs_with_nullable_largest_decimals,
            marks=pytest.mark.xfail(
                condition=is_databricks113_or_later(),
                reason='https://github.com/NVIDIA/spark-rapids/issues/7429')),
    ],
    ids=idfn)
def test_window_aggregations_for_big_decimal_ranges(data_gen):
    """
    Range-window aggregation with a DECIMAL order-by column and very large range bounds.

    Schema of the generated table:
      a: Group By (partition) column
      b: Order By column (decimal)
      c: Aggregation column (incidentally, also decimal)

    The subject under test is the order-by column type rather than a specific
    aggregation, so a single COUNT(1) is evaluated over a frame whose PRECEDING
    and FOLLOWING offsets are 37-digit decimal literals.
    """
    def build_table(spark):
        return gen_df(spark, data_gen, length=2048)

    window_query = (
        'SELECT '
        ' COUNT(1) OVER (PARTITION BY a ORDER BY b ASC '
        '                RANGE BETWEEN 12345678901234567890123456789012345.12 PRECEDING '
        '                          AND 11111111112222222222333333333344444.12 FOLLOWING) '
        'FROM window_agg_table'
    )

    assert_gpu_and_cpu_are_equal_sql(build_table, "window_agg_table", window_query, conf={})


# Column generators for the table queried by the collect_list and running-window tests in
# this file. Those queries partition by 'a' and order by 'b' (a UniqueLongGen) followed by
# 'c_int'; each 'c_*' column supplies one value type to aggregate or collect, including
# struct, array and map columns.
_gen_data_for_collect_list = [
    ('a', RepeatSeqGen(LongGen(), length=20)),
    ('b', UniqueLongGen()),
    ('c_bool', BooleanGen()),
    ('c_short', ShortGen()),
    ('c_int', IntegerGen()),
    ('c_long', LongGen()),
    ('c_date', DateGen()),
    ('c_ts', TimestampGen()),
    ('c_byte', ByteGen()),
    ('c_string', StringGen()),
    ('c_float', FloatGen()),
    ('c_double', DoubleGen()),
    ('c_decimal_32', DecimalGen(precision=8, scale=3)),
    ('c_decimal_64', decimal_gen_64bit),
    ('c_decimal_128', decimal_gen_128bit),
    # Struct column whose children repeat several of the scalar types above.
    ('c_struct', StructGen(children=[
        ['child_int', IntegerGen()],
        ['child_time', DateGen()],
        ['child_string', StringGen()],
        ['child_decimal_32', DecimalGen(precision=8, scale=3)],
        ['child_decimal_64', decimal_gen_64bit],
        ['child_decimal_128', decimal_gen_128bit]])),
    ('c_array', ArrayGen(int_gen)),
    ('c_map', simple_string_to_string_map_gen)]


# SortExec does not support array type, so the result is sorted locally instead.
@ignore_order(local=True)
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_rows_collect_list():
    """
    Compare CPU and GPU results of collect_list over a row-based window frame.

    Every value column of `_gen_data_for_collect_list` is collected over the
    frame CURRENT ROW .. UNBOUNDED FOLLOWING, partitioned by `a` and ordered by
    `b, c_int`. The query runs with
    `spark.rapids.sql.window.collectList.enabled` set to True.
    """
    def build_table(spark):
        return gen_df(spark, _gen_data_for_collect_list)

    collect_query = '''
        select
          collect_list(c_bool) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_bool,
          collect_list(c_short) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_short,
          collect_list(c_int) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_int,
          collect_list(c_long) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_long,
          collect_list(c_date) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_date,
          collect_list(c_ts) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_ts,
          collect_list(c_byte) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_byte,
          collect_list(c_string) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_string,
          collect_list(c_float) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_float,
          collect_list(c_double) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_double,
          collect_list(c_decimal_32) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_decimal_32,
          collect_list(c_decimal_64) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_decimal_64,
          collect_list(c_decimal_128) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_decimal_128,
          collect_list(c_struct) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_struct,
          collect_list(c_array) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_array,
          collect_list(c_map) over
            (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as collect_map
        from window_collect_table
        '''

    session_conf = {'spark.rapids.sql.window.collectList.enabled': True}
    assert_gpu_and_cpu_are_equal_sql(
        build_table, "window_collect_table", collect_query, conf=session_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# SortExec does not support array type, so the result is sorted locally instead.
@ignore_order(local=True)
# This test is aimed more at Databricks and its running-window optimization than at ours,
# which is why it does not (yet) validate that a GpuRunningWindowExec was inserted.
@allow_non_gpu(*non_utc_allow)
def test_running_window_function_exec_for_all_aggs():
    """
    Compare CPU and GPU results for a mix of running-window expressions.

    The query combines sum/min/max/count, row_number/rank/dense_rank and several
    collect_list calls. The framed expressions use
    UNBOUNDED PRECEDING .. CURRENT ROW; all are partitioned by `a` and ordered
    by `b, c_int`.
    """
    def build_table(spark):
        return gen_df(spark, _gen_data_for_collect_list)

    running_query = '''
        select
          sum(c_int) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as sum_int,
          min(c_long) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as min_long,
          max(c_date) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as max_date,
          count(1) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as count_1,
          count(*) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as count_star,
          row_number() over
            (partition by a order by b,c_int) as row_num,
          rank() over
            (partition by a order by b,c_int) as rank_val,
          dense_rank() over
            (partition by a order by b,c_int) as dense_rank_val,
          collect_list(c_float) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_float,
          collect_list(c_decimal_32) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_decimal_32,
          collect_list(c_decimal_64) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_decimal_64,
          collect_list(c_decimal_128) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_decimal_128,
          collect_list(c_struct) over
            (partition by a order by b,c_int rows between UNBOUNDED PRECEDING AND CURRENT ROW) as collect_struct
        from window_collect_table
        '''

    session_conf = {'spark.rapids.sql.window.collectList.enabled': True}
    assert_gpu_and_cpu_are_equal_sql(
        build_table, "window_collect_table", running_query, conf=session_conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
# Exercises the Databricks WindowExec, which combines a WindowExec with a ProjectExec. Its output
# fields have to be handled with an extra GpuProjectExec, and the input expressions are needed to
# compute a window function of another window function.
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', integral_gens, ids=idfn)
def test_join_sum_window_of_window(data_gen):
    """
    Join two generated tables and evaluate a window sum over a grouped sum.

    `agg` holds (a_1, c) with `c` drawn from `data_gen`; `part` holds (a_2, b)
    with byte values. The query joins them on a_1 = a_2, groups by b and c, and
    divides by `sum(sum(c)) over (partition by b)`.
    """
    def run_query(spark):
        left = gen_df(spark, StructGen([('a_1', UniqueLongGen()), ('c', data_gen)], nullable=False))
        right = gen_df(spark, StructGen([('a_2', UniqueLongGen()), ('b', byte_gen)], nullable=False))
        for view_name, table in (("agg", left), ("part", right)):
            table.createOrReplaceTempView(view_name)
        # Note: adding `c` to the select list (the output projection) hides the bug described in
        # https://github.com/NVIDIA/spark-rapids/issues/6531, so it is deliberately left out.
        ratio_query = """
        select
            b,
            sum(c) as sum_c,
            sum(c)/sum(sum(c)) over (partition by b) as ratio_sum,
            (b + c)/sum(sum(c)) over (partition by b) as ratio_bc
        from agg, part
        where a_1 = a_2
        group by b, c
        order by b, ratio_sum, ratio_bc"""
        return spark.sql(ratio_query)

    assert_gpu_and_cpu_are_equal_collect(run_query)

# Column generators for the collect_set window tests in this file.
# Each value column is wrapped in RepeatSeqGen so that values repeat, which exercises the
# deduplication done by GpuCollectSet. No struct column is included because GpuCollectSet
# does not yet support the struct type.
# The collect_set queries partition by 'a' and order by 'b' (a UniqueLongGen) and 'c_int'.
_gen_data_for_collect_set = [
    ('a', RepeatSeqGen(LongGen(), length=20)),
    ('b', UniqueLongGen()),
    ('c_bool', RepeatSeqGen(BooleanGen(), length=15)),
    ('c_int', RepeatSeqGen(IntegerGen(), length=15)),
    ('c_long', RepeatSeqGen(LongGen(), length=15)),
    ('c_short', RepeatSeqGen(ShortGen(), length=15)),
    ('c_date', RepeatSeqGen(DateGen(), length=15)),
    ('c_timestamp', RepeatSeqGen(TimestampGen(), length=15)),
    ('c_byte', RepeatSeqGen(ByteGen(), length=15)),
    ('c_string', RepeatSeqGen(StringGen(), length=15)),
    ('c_float', RepeatSeqGen(FloatGen(), length=15)),
    ('c_double', RepeatSeqGen(DoubleGen(), length=15)),
    ('c_decimal_32', RepeatSeqGen(DecimalGen(precision=8, scale=3), length=15)),
    ('c_decimal_64', RepeatSeqGen(decimal_gen_64bit, length=15)),
    ('c_decimal_128', RepeatSeqGen(decimal_gen_128bit, length=15)),
    # case to verify the NAN_UNEQUAL strategy: NaN is registered as a special case
    # (with weight 200.0) on a float column that repeats every 5 values
    ('c_fp_nan', RepeatSeqGen(FloatGen().with_special_case(math.nan, 200.0), length=5)),
]

# Column generators for the nested-type collect_set window test in this file.
# 'a', 'b' and 'c_int' are the partition / order-by columns of that query; every other column
# holds a nested value (struct/array combinations, or an array of arrays of one basic type)
# wrapped in RepeatSeqGen so that values repeat.
_gen_data_for_collect_set_nested = [
    ('a', RepeatSeqGen(LongGen(), length=20)),
    ('b', UniqueLongGen()),
    ('c_int', RepeatSeqGen(IntegerGen(), length=15)),
    ('c_struct_array_1', RepeatSeqGen(struct_array_gen, length=15)),
    # NOTE(review): length=14 here differs from the 15 used by the other value columns —
    # confirm this is intentional.
    ('c_struct_array_2', RepeatSeqGen(StructGen([
        ['c0', struct_array_gen], ['c1', int_gen]]), length=14)),
    ('c_array_struct', RepeatSeqGen(ArrayGen(all_basic_struct_gen), length=15)),
    ('c_array_array_bool', RepeatSeqGen(ArrayGen(ArrayGen(BooleanGen())), length=15)),
    ('c_array_array_int', RepeatSeqGen(ArrayGen(ArrayGen(IntegerGen())), length=15)),
    ('c_array_array_long', RepeatSeqGen(ArrayGen(ArrayGen(LongGen())), length=15)),
    ('c_array_array_short', RepeatSeqGen(ArrayGen(ArrayGen(ShortGen())), length=15)),
    ('c_array_array_date', RepeatSeqGen(ArrayGen(ArrayGen(DateGen())), length=15)),
    ('c_array_array_timestamp', RepeatSeqGen(ArrayGen(ArrayGen(TimestampGen())), length=15)),
    ('c_array_array_byte', RepeatSeqGen(ArrayGen(ArrayGen(ByteGen())), length=15)),
    ('c_array_array_string', RepeatSeqGen(ArrayGen(ArrayGen(StringGen())), length=15)),
    ('c_array_array_float', RepeatSeqGen(ArrayGen(ArrayGen(FloatGen())), length=15)),
    ('c_array_array_double', RepeatSeqGen(ArrayGen(ArrayGen(DoubleGen())), length=15)),
    ('c_array_array_decimal_32', RepeatSeqGen(ArrayGen(ArrayGen(DecimalGen(precision=8, scale=3))), length=15)),
    ('c_array_array_decimal_64', RepeatSeqGen(ArrayGen(ArrayGen(decimal_gen_64bit)), length=15)),
    ('c_array_array_decimal_128', RepeatSeqGen(ArrayGen(ArrayGen(decimal_gen_128bit)), length=15)),
]

# SortExec does not support array type, so the result is sorted locally instead.
@ignore_order(local=True)
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_rows_collect_set():
    """
    Compare CPU and GPU results of collect_set over a row-based window frame.

    The inner query collects each value column of `_gen_data_for_collect_set`
    over CURRENT ROW .. UNBOUNDED FOLLOWING (partitioned by `a`, ordered by
    `b, c_int`); the outer query applies sort_array() to every collected set.
    """
    def build_table(spark):
        return gen_df(spark, _gen_data_for_collect_set)

    collect_query = '''
        select a, b,
            sort_array(cc_bool),
            sort_array(cc_int),
            sort_array(cc_long),
            sort_array(cc_short),
            sort_array(cc_date),
            sort_array(cc_ts),
            sort_array(cc_byte),
            sort_array(cc_str),
            sort_array(cc_float),
            sort_array(cc_double),
            sort_array(cc_decimal_32),
            sort_array(cc_decimal_64),
            sort_array(cc_decimal_128),
            sort_array(cc_fp_nan)
        from (
            select a, b,
              collect_set(c_bool) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_bool,
              collect_set(c_int) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_int,
              collect_set(c_long) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_long,
              collect_set(c_short) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_short,
              collect_set(c_date) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_date,
              collect_set(c_timestamp) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_ts,
              collect_set(c_byte) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_byte,
              collect_set(c_string) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_str,
              collect_set(c_float) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_float,
              collect_set(c_double) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_double,
              collect_set(c_decimal_32) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_decimal_32,
              collect_set(c_decimal_64) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_decimal_64,
              collect_set(c_decimal_128) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_decimal_128,
              collect_set(c_fp_nan) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_fp_nan
            from window_collect_table
        ) t
        '''

    session_conf = {'spark.rapids.sql.window.collectSet.enabled': True}
    assert_gpu_and_cpu_are_equal_sql(
        build_table, "window_collect_table", collect_query, conf=session_conf)


@ignore_order(local=True)
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_fully_unbounded_partitioned_collect_set():
    """
    Confirms that a partitioned `collect_set` window aggregation over
    UNBOUNDED PRECEDING .. UNBOUNDED FOLLOWING is executed by
    `GpuUnboundedToUnboundedAggWindowExec` (which optimizes it to run via sort-based
    group-by aggregations).
    Note: This optimization only holds for the partitioned case.  Unpartitioned windows
    are not supported yet.
    """
    def build_table(spark):
        return gen_df(spark, _gen_data_for_collect_set, length=2048)

    collect_query = '''
        select a, b,
            sort_array(cc_bool),
            sort_array(cc_int),
            sort_array(cc_long),
            sort_array(cc_short),
            sort_array(cc_date),
            sort_array(cc_ts),
            sort_array(cc_byte),
            sort_array(cc_str),
            sort_array(cc_float),
            sort_array(cc_double),
            sort_array(cc_decimal_32),
            sort_array(cc_decimal_64),
            sort_array(cc_decimal_128),
            sort_array(cc_fp_nan)
        from (
            select a, b,
               collect_set(c_bool) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_bool,
               collect_set(c_int) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_int,
               collect_set(c_long) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_long,
               collect_set(c_short) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_short,
               collect_set(c_date) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_date,
               collect_set(c_timestamp) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_ts,
               collect_set(c_byte) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_byte,
               collect_set(c_string) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_str,
               collect_set(c_float) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_float,
               collect_set(c_double) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_double,
               collect_set(c_decimal_32) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_decimal_32,
               collect_set(c_decimal_64) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_decimal_64,
               collect_set(c_decimal_128) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_decimal_128,
               collect_set(c_fp_nan) over
                 (partition by a order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_fp_nan
            from window_collect_table
        ) t
        '''

    session_conf = {
        'spark.rapids.sql.window.collectSet.enabled': True,
        'spark.rapids.sql.window.unboundedAgg.enabled': True,
        'spark.sql.parquet.int96RebaseModeInWrite': 'LEGACY',
    }
    # The GPU plan must contain the unbounded-to-unbounded aggregation exec.
    required_execs = ['GpuUnboundedToUnboundedAggWindowExec']

    assert_gpu_and_cpu_are_equal_sql(
        build_table,
        "window_collect_table",
        collect_query,
        conf=session_conf,
        validate_execs_in_gpu_plan=required_execs)


@ignore_order(local=True)
@allow_non_gpu(*non_utc_allow)
def test_window_aggs_for_fully_unbounded_unpartitioned_collect_set():
    """
    Confirms that a `collect_set` window aggregation over
    UNBOUNDED PRECEDING .. UNBOUNDED FOLLOWING falls back to GpuWindowExec
    when no partition spec is specified.
    """
    def build_table(spark):
        return gen_df(spark, _gen_data_for_collect_set, length=2048)

    collect_query = '''
        select a, b,
            sort_array(cc_int),
            sort_array(cc_long),
            sort_array(cc_short)
        from (
            select a, b,
               collect_set(c_int) over
                 (order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_int,
               collect_set(c_long) over
                 (order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_long,
               collect_set(c_short) over
                 (order by b,c_int rows between UNBOUNDED PRECEDING and UNBOUNDED FOLLOWING) as cc_short
            from window_collect_table
        ) t
        '''

    session_conf = {
        'spark.rapids.sql.window.collectSet.enabled': True,
        'spark.rapids.sql.window.unboundedAgg.enabled': True,
        'spark.sql.parquet.int96RebaseModeInWrite': 'LEGACY',
    }
    # Without a PARTITION BY clause the plan is expected to use the regular window exec.
    required_execs = ['GpuWindowExec']

    assert_gpu_and_cpu_are_equal_sql(
        build_table,
        "window_collect_table",
        collect_query,
        conf=session_conf,
        validate_execs_in_gpu_plan=required_execs)


# Note, using sort_array() on the CPU, because sort_array() does not yet
# support sorting certain nested/arbitrary types on the GPU.
# See https://github.com/NVIDIA/spark-rapids/issues/3715
# and https://github.com/rapidsai/cudf/issues/11222
# NOTE(review): the query below does not call sort_array(); the collected columns are instead
# listed in `ignore_order(arrays=...)` — confirm the note above is still accurate.
@ignore_order(local=True, arrays=[
        "cc_struct_array_1",
        "cc_struct_array_2",
        "cc_array_struct",
        "cc_array_array_bool",
        "cc_array_array_int",
        "cc_array_array_long",
        "cc_array_array_short",
        "cc_array_array_date",
        "cc_array_array_ts",
        "cc_array_array_byte",
        "cc_array_array_str",
        "cc_array_array_float",
        "cc_array_array_double",
        "cc_array_array_decimal_32",
        "cc_array_array_decimal_64",
        "cc_array_array_decimal_128"
])
@allow_non_gpu("ProjectExec", *non_utc_allow)
def test_window_aggs_for_rows_collect_set_nested_array():
    """
    Compare CPU and GPU results of collect_set over nested-type columns.

    Each nested column of `_gen_data_for_collect_set_nested` is collected over
    CURRENT ROW .. UNBOUNDED FOLLOWING, partitioned by `a` and ordered by
    `b, c_int`, on a 512-row table.
    """
    overrides = {
        "spark.rapids.sql.castFloatToString.enabled": "true",
        'spark.rapids.sql.window.collectSet.enabled': "true"
    }
    conf = copy_and_update(_float_conf, overrides)

    nested_query = """select a, b,
              collect_set(c_struct_array_1) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_struct_array_1,
              collect_set(c_struct_array_2) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_struct_array_2,
              collect_set(c_array_struct) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_struct,
              collect_set(c_array_array_bool) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_bool,
              collect_set(c_array_array_int) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_int,
              collect_set(c_array_array_long) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_long,
              collect_set(c_array_array_short) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_short,
              collect_set(c_array_array_date) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_date,
              collect_set(c_array_array_timestamp) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_ts,
              collect_set(c_array_array_byte) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_byte,
              collect_set(c_array_array_string) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_str,
              collect_set(c_array_array_float) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_float,
              collect_set(c_array_array_double) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_double,
              collect_set(c_array_array_decimal_32) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_decimal_32,
              collect_set(c_array_array_decimal_64) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_decimal_64,
              collect_set(c_array_array_decimal_128) over
                (partition by a order by b,c_int rows between CURRENT ROW and UNBOUNDED FOLLOWING) as cc_array_array_decimal_128
        from window_collect_table
        """

    def run_query(spark):
        # Register the generated table under the name the query selects from.
        table = gen_df(spark, _gen_data_for_collect_set_nested, length=512)
        table.createOrReplaceTempView("window_collect_table")
        return spark.sql(nested_query)

    assert_gpu_and_cpu_are_equal_collect(run_query, conf=conf)


# Partitions may come back in a different order on a distributed cluster, so row order is ignored.
# The final ordering is done locally because small batch sizes can make a sort very slow.
@ignore_order(local=True)
# Arrays and struct of struct (more than single level nesting) are not supported
@pytest.mark.parametrize('part_gen', [ArrayGen(long_gen), StructGen([["a", StructGen([["a1", long_gen]])]])], ids=meta_idfn('partBy:'))
# For arrays the sort and hash partition are also not supported
@allow_non_gpu('WindowExec', 'Alias', 'WindowExpression', 'AggregateExpression', 'Count', 'WindowSpecDefinition', 'SpecifiedWindowFrame', 'Literal', 'SortExec', 'SortOrder', 'ShuffleExchangeExec', 'HashPartitioning')
def test_nested_part_fallback(part_gen):
    """
    Check that a window partitioned by an unsupported nested type falls back to the CPU.

    Column `a` (the partition key) is generated from `part_gen`; a count of `c`
    over a frame of 5 rows before to 5 rows after is added as column `rn`, and
    the test asserts a fallback to 'WindowExec'.
    """
    data_gen = [
        ('a', RepeatSeqGen(part_gen, length=20)),
        ('b', UniqueLongGen()),
        ('c', int_gen),
    ]

    partitioned = Window.partitionBy('a')
    window_spec = partitioned.orderBy('b').rowsBetween(-5, 5)

    def add_window_count(spark):
        table = gen_df(spark, data_gen, length=2048)
        return table.withColumn('rn', f.count('c').over(window_spec))

    assert_gpu_fallback_collect(add_window_count, 'WindowExec')


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
# single-level structs (no nested structs) are now supported by the plugin
@pytest.mark.parametrize('part_gen', [StructGen([["a", long_gen]])], ids=meta_idfn('partBy:'))
def test_nested_part_struct(part_gen):
    """
    Compare CPU and GPU results for a window partitioned by a single-level struct.

    Column `a` (the partition key) is generated from `part_gen`; a count of `c`
    over a frame of 5 rows before to 5 rows after is added as column `rn`.
    """
    data_gen = [
        ('a', RepeatSeqGen(part_gen, length=20)),
        ('b', UniqueLongGen()),
        ('c', int_gen),
    ]

    partitioned = Window.partitionBy('a')
    window_spec = partitioned.orderBy('b').rowsBetween(-5, 5)

    def add_window_count(spark):
        table = gen_df(spark, data_gen, length=2048)
        return table.withColumn('rn', f.count('c').over(window_spec))

    assert_gpu_and_cpu_are_equal_collect(add_window_count)

# In a distributed setup the order of the partitions returned might be different, so we must ignore the order
# but small batch sizes can make sort very slow, so do the final order by locally
@ignore_order(local=True)
@pytest.mark.parametrize('ride_along', all_basic_gens + decimal_gens + array_gens_sample + struct_gens_sample + map_gens_sample, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_window_ride_along(ride_along):
    """
    Check that a column not used by the window expression passes through unchanged.

    Column `b` (generated from `ride_along`) is selected via `*` next to a
    row_number() computed over column `a` only.
    """
    def build_table(spark):
        # The column list is built on every call, so each session gets fresh generators.
        columns = [('a', UniqueLongGen()), ('b', ride_along)]
        return gen_df(spark, columns)

    ride_along_query = (
        'select *,'
        ' row_number() over (order by a) as row_num '
        'from window_agg_table '
    )

    assert_gpu_and_cpu_are_equal_sql(build_table, "window_agg_table", ride_along_query)

@approximate_float
@ignore_order
@pytest.mark.parametrize('preceding', [Window.unboundedPreceding, -4], ids=idfn)
@pytest.mark.parametrize('following', [Window.unboundedFollowing, 3], ids=idfn)
def test_window_range_stddev(preceding, following):
    """
    Compare CPU and GPU results of stddev over a range-based window on a long column.

    The frame runs from `preceding` to `following`, partitioned by `_1` and
    ordered by `_2`; only the resulting `standard_dev` column is compared.
    """
    ordered = Window.partitionBy("_1").orderBy("_2")
    window_spec_agg = ordered.rangeBetween(preceding, following)

    def compute_stddev(spark):
        # rangeBetween uses the actual value of the ordering column, so the generated longs are
        # kept away from LONG_MIN_VALUE / LONG_MAX_VALUE by at least the largest preceding and
        # following offsets; otherwise computing the frame bounds would overflow.
        long_values = LongGen(min_val=-(1 << 63) + 4, max_val=(1 << 63) - 4)
        columns = [('_1', RepeatSeqGen(long_values, length=20)), ('_2', long_values)]
        table = gen_df(spark, columns)
        with_stddev = table.withColumn("standard_dev", f.stddev("_2").over(window_spec_agg))
        return with_stddev.selectExpr("standard_dev")

    session_conf = { 'spark.rapids.sql.window.range.long.enabled': 'true'}
    assert_gpu_and_cpu_are_equal_collect(compute_stddev, conf=session_conf)

@approximate_float
@ignore_order
@pytest.mark.parametrize('preceding', [Window.unboundedPreceding, -4], ids=idfn)
@pytest.mark.parametrize('following', [Window.unboundedFollowing, 3], ids=idfn)
def test_window_rows_stddev(preceding, following):
    """
    Compare CPU and GPU results of stddev over a row-based window on a double column.

    The frame runs from `preceding` to `following` rows, partitioned by `_1`
    and ordered by `_2`; only the resulting `standard_dev` column is compared.
    """
    ordered = Window.partitionBy("_1").orderBy("_2")
    window_spec_agg = ordered.rowsBetween(preceding, following)

    def compute_stddev(spark):
        columns = [('_1', RepeatSeqGen(IntegerGen(), length=20)), ('_2', DoubleGen())]
        table = gen_df(spark, columns)
        with_stddev = table.withColumn("standard_dev", f.stddev("_2").over(window_spec_agg))
        return with_stddev.selectExpr("standard_dev")

    assert_gpu_and_cpu_are_equal_collect(compute_stddev)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order
def test_unbounded_to_unbounded_window():
    """
    Regression test for an overflow when calculating the range of some row-based queries.

    The bug affected more than unbounded-to-unbounded windows, but this is the
    simplest query that exercises it: SUM and COUNT over an empty window spec
    on `spark.range(1024)`.
    """
    def run_query(spark):
        ids = spark.range(1024)
        return ids.selectExpr(
            'SUM(id) OVER ()',
            'COUNT(1) OVER ()')

    assert_gpu_and_cpu_are_equal_collect(run_query)


_nested_gens = array_gens_sample + struct_gens_sample + map_gens_sample + [binary_gen]

# Select-list fragment with first(), last() and NTH_VALUE() over RANGE and ROWS frames,
# followed by first/last called with a second argument of `true`.
exprs_for_nth_first_last = (
    'first(a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'first(a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'first(a) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'first(a) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'last (a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'last (a) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'last (a) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'last (a) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'NTH_VALUE(a, 1) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'NTH_VALUE(a, 2) OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'NTH_VALUE(a, 3) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), '
    'NTH_VALUE(a, 3) OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'first(a, true) OVER (PARTITION BY b ORDER BY c), '
    'last (a, true) OVER (PARTITION BY b ORDER BY c), '
    'last (a, true) OVER (PARTITION BY b ORDER BY c) '
)

# Select-list fragment using the IGNORE NULLS clause with NTH_VALUE(), first() and last().
exprs_for_nth_first_last_ignore_nulls = (
    'NTH_VALUE(a, 1) IGNORE NULLS OVER (PARTITION BY b ORDER BY c RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'first(a) IGNORE NULLS OVER (PARTITION BY b ORDER BY c ROWS  BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), '
    'last(a) IGNORE NULLS OVER (PARTITION BY b ORDER BY c) '
)

@pytest.mark.parametrize('data_gen', all_basic_gens_no_null + decimal_gens + _nested_gens, ids=idfn)
@allow_non_gpu(*non_utc_allow)
def test_window_first_last_nth(data_gen):
    """Compare CPU and GPU results of first(), last() and NTH_VALUE() window expressions."""
    query = 'SELECT a, b, c, ' + exprs_for_nth_first_last + 'FROM window_agg_table'

    def single_partition_df(spark):
        # Coalesce is to make sure that first and last, which are non-deterministic,
        # become deterministic.
        generated = three_col_df(spark, data_gen, string_gen, int_gen, num_slices=1)
        return generated.coalesce(1)

    assert_gpu_and_cpu_are_equal_sql(single_partition_df, "window_agg_table", query)

@pytest.mark.skipif(is_before_spark_320(), reason='IGNORE NULLS clause is not supported for FIRST(), LAST() and NTH_VALUE in Spark 3.1.x')
@pytest.mark.parametrize('data_gen', all_basic_gens_no_null + decimal_gens + _nested_gens, ids=idfn)
def test_window_first_last_nth_ignore_nulls(data_gen):
    """Compare CPU and GPU results of first(), last() and NTH_VALUE() with IGNORE NULLS."""
    query = 'SELECT a, b, c, ' + exprs_for_nth_first_last_ignore_nulls + 'FROM window_agg_table'

    def single_partition_df(spark):
        # Coalesce is to make sure that first and last, which are non-deterministic,
        # become deterministic.
        generated = three_col_df(spark, data_gen, string_gen, int_gen, num_slices=1)
        return generated.coalesce(1)

    assert_gpu_and_cpu_are_equal_sql(single_partition_df, "window_agg_table", query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@tz_sensitive_test
@allow_non_gpu(*non_supported_tz_allow)
@ignore_order(local=True)
def test_to_date_with_window_functions():
    """
    Verify that a date expression selected next to a window aggregation is initialized
    correctly. (See: https://github.com/NVIDIA/spark-rapids/issues/5984)

    On some vendor-specific Spark versions the date expression may be evaluated inside
    the WindowExec itself rather than in an upstream projection. The CPU plan for the
    query below can then look like this:
    ```
      Window [cast(gettimestamp(cast(date_1#1 as string), yyyy-MM-dd, TimestampType, Some(Etc/UTC), false) as date)...]
      +- Sort [id#0L ASC NULLS FIRST, date_2#2 ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#0L, 200), ENSURE_REQUIREMENTS, [id=#136]
            +- *(1) Project [date_1#1, id#0L, date_2#2]
    ```

    That shape could leave `GpuGetTimeStamp` for `date_1` only partially initialized
    in the GPU plan:
    ```
    +- GpuProject [cast(gpugettimestamp(cast(date_1#1 as string), yyyy-MM-dd, null, null, None) as date) AS my_date#6]
    ```

    whereas a correct initialization yields:
    ```
    +- GpuProject [cast(gpugettimestamp(cast(date_1#1 as string), yyyy-MM-dd, yyyy-MM-dd, %Y-%m-%d, None) as date)]
    ```
    """
    def window_input_df(spark):
        # 'id' is the partition key; 'date_1' feeds TO_DATE and 'date_2' orders the window.
        columns = [('id', RepeatSeqGen(int_gen, 20)),
                   ('date_1', DateGen()),
                   ('date_2', DateGen())]
        return gen_df(spark, columns)

    query = """
        SELECT TO_DATE( CAST(date_1 AS STRING), 'yyyy-MM-dd' ) AS my_date,
               SUM(1) OVER(PARTITION BY id ORDER BY date_2) AS my_sum
        FROM window_input
        """

    assert_gpu_and_cpu_are_equal_sql(
        df_fun=window_input_df,
        table_name="window_input",
        sql=query)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
                                      _grpkey_longs_with_nulls,
                                      _grpkey_longs_with_dates,
                                      _grpkey_longs_with_nullable_dates,
                                      _grpkey_longs_with_decimals,
                                      _grpkey_longs_with_nullable_decimals,
                                      _grpkey_longs_with_nullable_larger_decimals
                                      ], ids=idfn)
@pytest.mark.parametrize('window_spec', ["3 PRECEDING AND -1 FOLLOWING",
                                         "-2 PRECEDING AND 4 FOLLOWING",
                                         "UNBOUNDED PRECEDING AND -1 FOLLOWING",
                                         "-1 PRECEDING AND UNBOUNDED FOLLOWING",
                                         "10 PRECEDING AND -1 FOLLOWING",
                                         "5 PRECEDING AND -2 FOLLOWING"], ids=idfn)
def test_window_aggs_for_negative_rows_partitioned(data_gen, batch_size, window_spec):
    """Compare CPU and GPU aggregations over partitioned ROWS frames with negative offsets."""
    # Every aggregation shares the same frame; `{window}` is filled from `window_spec`.
    query_template = (
        'SELECT '
        ' SUM(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ASC ROWS BETWEEN {window}) AS sum_c_asc, '
        ' MAX(c) OVER '
        '   (PARTITION BY a ORDER BY b DESC, c DESC ROWS BETWEEN {window}) AS max_c_desc, '
        ' MIN(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS min_c_asc, '
        ' COUNT(1) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_1, '
        ' COUNT(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS count_c, '
        ' AVG(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS avg_c, '
        ' COLLECT_LIST(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window}) AS list_c, '
        ' SORT_ARRAY(COLLECT_SET(c) OVER '
        '   (PARTITION BY a ORDER BY b,c ROWS BETWEEN {window})) AS sorted_set_c '
        'FROM window_agg_table ')
    query = query_template.format(window=window_spec)

    switched_on = ('spark.rapids.sql.castFloatToDecimal.enabled',
                   'spark.rapids.sql.window.collectSet.enabled',
                   'spark.rapids.sql.window.collectList.enabled')
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
    conf.update((key, True) for key in switched_on)

    def input_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(input_df, "window_agg_table", query, conf=conf)


def spark_bugs_in_decimal_sorting(version=None):
    """
    Checks whether Apache Spark version has a bug in sorting Decimal columns correctly.
    See https://issues.apache.org/jira/browse/SPARK-40089.

    The fix was released on several maintenance lines (3.1.4, 3.2.3, 3.3.1) and in 3.4.0,
    so the check is made per release line on the numeric version components. A single
    string comparison would treat every release below 3.4.0 as affected.

    :param version: Spark version string to check. Defaults to the value returned by
                    spark_version() for the running session.
    :return: True, if Apache Spark version does not sort Decimal(>20, >2) correctly. False, otherwise.
    """
    import re  # local import keeps this helper self-contained

    v = spark_version() if version is None else version

    # Only the leading major.minor.patch matters; vendor suffixes are ignored.
    parsed = re.match(r'(\d+)\.(\d+)\.(\d+)', v)
    if parsed is None:
        # Unrecognised version format: fall back to the plain string comparison.
        return v < "3.4.0"

    major, minor, patch = (int(component) for component in parsed.groups())

    # First patch release carrying the fix on each affected maintenance line.
    first_fixed_patch = {(3, 1): 4, (3, 2): 3, (3, 3): 1}
    release_line = (major, minor)
    if release_line in first_fixed_patch:
        return patch < first_fixed_patch[release_line]

    # Lines older than 3.1 never received the fix; 3.4.0 and later all contain it.
    return release_line < (3, 4)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
                                      _grpkey_longs_with_nulls,
                                      _grpkey_longs_with_dates,
                                      _grpkey_longs_with_nullable_dates,
                                      _grpkey_longs_with_decimals,
                                      _grpkey_longs_with_nullable_decimals,
                                      pytest.param(_grpkey_longs_with_nullable_larger_decimals,
                                                   marks=pytest.mark.skipif(
                                                     condition=spark_bugs_in_decimal_sorting(),
                                                     reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))],
                         ids=idfn)
def test_window_aggs_for_negative_rows_unpartitioned(data_gen, batch_size):
    """Compare CPU and GPU aggregations over ROWS frames with negative offsets.

    Most aggregations use windows without PARTITION BY; the COLLECT_LIST and
    COLLECT_SET expressions are partitioned by `a`.
    """
    query = (
        'SELECT '
        ' SUM(c) OVER '
        '   (ORDER BY b,c,a ROWS BETWEEN 3 PRECEDING AND -1 FOLLOWING) AS sum_c_asc, '
        ' MAX(c) OVER '
        '   (ORDER BY b DESC, c DESC, a DESC ROWS BETWEEN -2 PRECEDING AND 4 FOLLOWING) AS max_c_desc, '
        ' min(c) OVER '
        '   (ORDER BY b,c,a ROWS BETWEEN UNBOUNDED PRECEDING AND -1 FOLLOWING) AS min_c_asc, '
        ' COUNT(1) OVER '
        '   (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS count_1, '
        ' COUNT(c) OVER '
        '   (ORDER BY b,c,a ROWS BETWEEN 10 PRECEDING AND -1 FOLLOWING) AS count_c, '
        ' AVG(c) OVER '
        '   (ORDER BY b,c,a ROWS BETWEEN -1 PRECEDING AND UNBOUNDED FOLLOWING) AS avg_c, '
        ' COLLECT_LIST(c) OVER '
        '   (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING) AS list_c, '
        ' SORT_ARRAY(COLLECT_SET(c) OVER '
        '   (PARTITION BY a ORDER BY b,c,a ROWS BETWEEN 5 PRECEDING AND -2 FOLLOWING)) AS set_c '
        'FROM window_agg_table ')

    switched_on = ('spark.rapids.sql.castFloatToDecimal.enabled',
                   'spark.rapids.sql.window.collectSet.enabled',
                   'spark.rapids.sql.window.collectList.enabled')
    conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
    conf.update((key, True) for key in switched_on)

    def input_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(input_df, "window_agg_table", query, conf=conf)


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [
    _grpkey_short_with_nulls,
    _grpkey_int_with_nulls,
    _grpkey_long_with_nulls,
    _grpkey_date_with_nulls,
], ids=idfn)
def test_window_aggs_for_batched_finite_row_windows_partitioned(data_gen, batch_size):
    """Compare CPU and GPU results for finite ROWS windows partitioned by `a`.

    The assert helper is also asked to validate that `GpuBatchedBoundedWindowExec`
    is present in the GPU plan.
    """
    query = """
        SELECT
          COUNT(1) OVER (PARTITION BY a ORDER BY b,c ASC
                         ROWS BETWEEN CURRENT ROW AND 100 FOLLOWING) AS count_1_asc,
          COUNT(c) OVER (PARTITION BY a ORDER BY b,c ASC 
                         ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS count_c_asc,
          COUNT(c) OVER (PARTITION BY a ORDER BY b,c ASC 
                         ROWS BETWEEN -50 PRECEDING AND 100 FOLLOWING) AS count_c_negative,
          COUNT(1) OVER (PARTITION BY a ORDER BY b,c ASC 
                         ROWS BETWEEN 50 PRECEDING AND -10 FOLLOWING) AS count_1_negative,
          SUM(c) OVER (PARTITION BY a ORDER BY b,c ASC 
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS sum_c_asc, 
          AVG(c) OVER (PARTITION BY a ORDER BY b,c ASC
                       ROWS BETWEEN 10 PRECEDING AND 30 FOLLOWING) AS avg_c_asc,
          MAX(c) OVER (PARTITION BY a ORDER BY b,c DESC
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS max_c_desc,
          MIN(c) OVER (PARTITION BY a ORDER BY b,c ASC
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS min_c_asc,
          LAG(c, 30) OVER (PARTITION BY a ORDER BY b,c ASC) AS lag_c_30_asc,
          LEAD(c, 40) OVER (PARTITION BY a ORDER BY b,c ASC) AS lead_c_40_asc
        FROM window_agg_table
        """

    def input_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(
        input_df,
        'window_agg_table',
        query,
        validate_execs_in_gpu_plan=['GpuBatchedBoundedWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': batch_size})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [
    _grpkey_short_with_nulls,
    _grpkey_int_with_nulls,
    _grpkey_long_with_nulls,
    _grpkey_date_with_nulls,
], ids=idfn)
def test_window_aggs_for_batched_finite_row_windows_unpartitioned(data_gen, batch_size):
    """Compare CPU and GPU results for finite ROWS windows, starting with an unpartitioned one.

    The first COUNT(1) window has no PARTITION BY clause; the remaining expressions
    partition by `a`. The assert helper is also asked to validate that
    `GpuBatchedBoundedWindowExec` is present in the GPU plan.
    """
    query = """
        SELECT
          COUNT(1) OVER (ORDER BY b,c,a ASC
                         ROWS BETWEEN CURRENT ROW AND 100 FOLLOWING) AS count_1_asc,
          COUNT(c) OVER (PARTITION BY a ORDER BY b,c,a ASC 
                         ROWS BETWEEN 100 PRECEDING AND CURRENT ROW) AS count_c_asc,
          COUNT(c) OVER (PARTITION BY a ORDER BY b,c,a ASC 
                         ROWS BETWEEN -50 PRECEDING AND 100 FOLLOWING) AS count_c_negative,
          COUNT(1) OVER (PARTITION BY a ORDER BY b,c,a ASC 
                         ROWS BETWEEN 50 PRECEDING AND -10 FOLLOWING) AS count_1_negative,
          SUM(c) OVER (PARTITION BY a ORDER BY b,c,a ASC 
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS sum_c_asc, 
          AVG(c) OVER (PARTITION BY a ORDER BY b,c,a ASC
                       ROWS BETWEEN 10 PRECEDING AND 30 FOLLOWING) AS avg_c_asc,
          MAX(c) OVER (PARTITION BY a ORDER BY b,c,a DESC
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS max_c_desc,
          MIN(c) OVER (PARTITION BY a ORDER BY b,c,a ASC
                       ROWS BETWEEN 1 PRECEDING AND 3 FOLLOWING) AS min_c_asc,
          LAG(c, 6)  OVER (PARTITION BY a ORDER BY b,c,a ASC) AS lag_c_6,
          LEAD(c,4)  OVER (PARTITION BY a ORDER BY b,c,a ASC) AS lead_c_4
        FROM window_agg_table
        """

    def input_df(spark):
        return gen_df(spark, data_gen, length=2048)

    assert_gpu_and_cpu_are_equal_sql(
        input_df,
        'window_agg_table',
        query,
        validate_execs_in_gpu_plan=['GpuBatchedBoundedWindowExec'],
        conf={'spark.rapids.sql.batchSizeBytes': batch_size})


@disable_ansi_mode  # https://github.com/NVIDIA/spark-rapids/issues/5114
@ignore_order(local=True)
@pytest.mark.parametrize('data_gen', [_grpkey_int_with_nulls,], ids=idfn)
def test_window_aggs_for_batched_finite_row_windows_fallback(data_gen):
    """
    Verify that batching is disabled for bounded windows when the window extents
    exceed the maximum extent configured in the RAPIDS conf.
    """
    max_extent_key = 'spark.rapids.sql.window.batched.bounded.row.max'

    # Query with window extent = { 200 PRECEDING, 200 FOLLOWING }.
    query = """
        SELECT
          COUNT(1) OVER (PARTITION BY a ORDER BY b,c ASC
                         ROWS BETWEEN 200 PRECEDING AND 200 FOLLOWING) AS count_1_asc    
        FROM window_agg_table                 
    """

    def expect_exec(exec_name, max_extent):
        # Run the query with the given maximum batched extent and ask the assert
        # helper to validate that `exec_name` is in the GPU plan.
        assert_gpu_and_cpu_are_equal_sql(
            lambda spark: gen_df(spark, data_gen, length=2048),
            'window_agg_table',
            query,
            validate_execs_in_gpu_plan=[exec_name],
            conf={'spark.rapids.sql.batchSizeBytes': '1000',
                  max_extent_key: max_extent})

    # With the max window extent set to 100, the query runs without batching,
    # i.e. on `GpuWindowExec`.
    expect_exec('GpuWindowExec', 100)

    # With the max window extent set to 200, the query runs *with* batching,
    # i.e. on `GpuBatchedBoundedWindowExec`.
    expect_exec('GpuBatchedBoundedWindowExec', 200)


@pytest.mark.skipif(condition=not (is_spark_350_or_later() or is_databricks133_or_later()),
                    reason="WindowGroupLimit not available for spark.version < 3.5 "
                           "and Databricks version < 13.3")
@ignore_order(local=True)
@approximate_float
@pytest.mark.parametrize('batch_size', ['1k', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [_grpkey_longs_with_no_nulls,
                                      _grpkey_longs_with_nulls,
                                      _grpkey_longs_with_dates,
                                      _grpkey_longs_with_nullable_dates,
                                      _grpkey_longs_with_decimals,
                                      _grpkey_longs_with_nullable_decimals,
                                      pytest.param(_grpkey_longs_with_nullable_larger_decimals,
                                                   marks=pytest.mark.skipif(
                                                       condition=spark_bugs_in_decimal_sorting(),
                                                       reason='https://github.com/NVIDIA/spark-rapids/issues/7429'))
                                      ],
                         ids=idfn)
@pytest.mark.parametrize('rank_clause', [
                            'RANK() OVER (PARTITION BY a ORDER BY b, c) ',
                            'DENSE_RANK() OVER (PARTITION BY a ORDER BY b, c) ',
                            'RANK() OVER (ORDER BY a,b,c) ',
                            'DENSE_RANK() OVER (ORDER BY a,b,c) ',
                            # ROW_NUMBER() on an un-partitioned window does not invoke WindowGroupLimit optimization.
                            'ROW_NUMBER() OVER (PARTITION BY a ORDER BY b,c) ',
])
def test_window_group_limits_for_ranking_functions(data_gen, batch_size, rank_clause):
    """
    Verify that window group limits are applied for queries that filter rows on the
    result of a ranking function.
    RANK() and DENSE_RANK() are covered with and without a `PARTITION BY` clause;
    ROW_NUMBER() is covered with a `PARTITION BY` clause only.
    """
    # The ranking expression is computed in a subquery and then used as the row filter.
    query_template = """
        SELECT * FROM (
          SELECT *, {} AS rnk
          FROM window_agg_table
        )
        WHERE rnk < 3
     """
    query = query_template.format(rank_clause)

    conf = {'spark.rapids.sql.batchSizeBytes': batch_size}
    conf['spark.rapids.sql.castFloatToDecimal.enabled'] = True

    def input_df(spark):
        return gen_df(spark, data_gen, length=4096)

    assert_gpu_and_cpu_are_equal_sql(input_df, "window_agg_table", query, conf=conf)


def test_lru_cache_datagen():
    """Report the `gen_df_help` cache statistics as a warning, then clear that cache."""
    # Unrelated to window functions: this just logs the cache info once the
    # integration tests in this file have run.
    message = "Cache info: {}".format(gen_df_help.cache_info())
    warnings.warn(message)
    gen_df_help.cache_clear()
